diff --git a/Android/app/build.gradle b/Android/app/build.gradle index 6dedc622..ad822e07 100644 --- a/Android/app/build.gradle +++ b/Android/app/build.gradle @@ -24,14 +24,14 @@ android { storePassword keystoreProperties['storePassword'] } } - compileSdkVersion 33 - buildToolsVersion '33.0.0' + compileSdkVersion ANDROID_COMPILE_SDK_VERSION as int + buildToolsVersion ANDROID_BUILD_TOOLS_VERSION defaultConfig { applicationId "app.intra" // Firebase Crashlytics requires SDK version 16. - minSdkVersion 16 - targetSdkVersion 33 + minSdkVersion ANDROID_MIN_SDK_VERSION as int + targetSdkVersion ANDROID_TARGET_SDK_VERSION as int versionCode 64 versionName "1.3.7" vectorDrawables.useSupportLibrary = true @@ -107,7 +107,7 @@ dependencies { implementation 'com.google.firebase:firebase-crashlytics-ndk:18.2.6' implementation 'com.google.firebase:firebase-config:21.0.1' // For go-tun2socks - implementation project(":tun2socks") + implementation project(path: ':tun2socks', configuration: 'aarBinary') } // For Firebase Analytics diff --git a/Android/gradle.properties b/Android/gradle.properties index 915f0e66..4bd4909d 100644 --- a/Android/gradle.properties +++ b/Android/gradle.properties @@ -17,4 +17,10 @@ # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects # org.gradle.parallel=true android.enableJetifier=true -android.useAndroidX=true \ No newline at end of file +android.useAndroidX=true + +ANDROID_COMPILE_SDK_VERSION=33 +ANDROID_BUILD_TOOLS_VERSION=33.0.0 + +ANDROID_MIN_SDK_VERSION=16 +ANDROID_TARGET_SDK_VERSION=33 diff --git a/Android/tun2socks/README.md b/Android/tun2socks/README.md deleted file mode 100644 index dbec9f31..00000000 --- a/Android/tun2socks/README.md +++ /dev/null @@ -1,2 +0,0 @@ -This copy of tun2socks.aar is built from https://github.com/Jigsaw-Code/outline-go-tun2socks, at -commit af29b3b614fd7695675004af521cb0f1cc0b3838. It is used here under the Apache 2.0 license. diff --git a/Android/tun2socks/build.gradle b/Android/tun2socks/build.gradle index 1faf7e16..26987183 100644 --- a/Android/tun2socks/build.gradle +++ b/Android/tun2socks/build.gradle @@ -1,2 +1,78 @@ -configurations.maybeCreate("default") -artifacts.add("default", file('tun2socks.aar')) \ No newline at end of file +// We use Android library plugin to get the Android SDK path. +plugins { + id('com.android.library') +} + +android { + compileSdkVersion ANDROID_COMPILE_SDK_VERSION as int + defaultConfig { + minSdkVersion ANDROID_MIN_SDK_VERSION as int + } +} + +configurations { + aarBinary { + canBeConsumed = true + canBeResolved = false + } +} + +def goBuildDir = file("${buildDir}/go") +def outputAAR = file("${buildDir}/tun2socks.aar") + +def srcDir = "${rootDir}/tun2socks/intra" +def srcPackages = [srcDir, + "${srcDir}/android", + "${srcDir}/doh", + "${srcDir}/split", + "${srcDir}/protect"] + +// Make sure that the go build directory exists. +task ensureBuildDir() { + doLast { + goBuildDir.mkdirs() + } +} + +// Install `gomobile` and `gobind` to the build directory, so that the user +// does not need to install them on their system or call `gomobile init`. +task ensureGoMobile(type: Exec, dependsOn: ensureBuildDir) { + // Define outputs so this task will only be executed when they don't exist + outputs.file("${goBuildDir}/gomobile") + outputs.file("${goBuildDir}/gobind") + + commandLine('go', 'build', + '-o', goBuildDir, + 'golang.org/x/mobile/cmd/gomobile', + 'golang.org/x/mobile/cmd/gobind') +} + +// Invoke `gomobile bind` to build from `srcPackages` to `outputAAR`. +// `gomobile` needs the `ANDROID_HOME` environment variable to be set, and the +// parent directory of `gobind` must be in the `PATH` as well. +task gobind(type: Exec, dependsOn: ensureGoMobile) { + // Define inputs and outputs so Gradle will enable incremental builds + inputs.dir(srcDir) + outputs.file(outputAAR) + + workingDir goBuildDir + environment 'ANDROID_HOME', android.sdkDirectory + environment 'PATH', goBuildDir.getPath() + + System.getProperty('path.separator') + + System.getenv('PATH') + + commandLine("${goBuildDir}/gomobile", 'bind', + '-ldflags=-s -w', + '-target=android', + "-androidapi=${android.defaultConfig.minSdk}", + '-o', outputAAR, + *srcPackages) +} + +// AAR file that can be consumed by other projects. For example: +// implementation project(path: ':tun2socks', configuration: 'aarBinary') +artifacts { + aarBinary(outputAAR) { + builtBy(gobind) + } +} diff --git a/Android/tun2socks/intra/android/init.go b/Android/tun2socks/intra/android/init.go new file mode 100644 index 00000000..a91226f8 --- /dev/null +++ b/Android/tun2socks/intra/android/init.go @@ -0,0 +1,27 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun2socks + +import ( + "runtime/debug" + + "github.com/eycorsican/go-tun2socks/common/log" +) + +func init() { + // Conserve memory by increasing garbage collection frequency. + debug.SetGCPercent(10) + log.SetLevel(log.WARN) +} diff --git a/Android/tun2socks/intra/android/tun.go b/Android/tun2socks/intra/android/tun.go new file mode 100644 index 00000000..e3639dc9 --- /dev/null +++ b/Android/tun2socks/intra/android/tun.go @@ -0,0 +1,38 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun2socks + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func makeTunFile(fd int) (*os.File, error) { + if fd < 0 { + return nil, errors.New("must provide a valid TUN file descriptor") + } + // Make a copy of `fd` so that os.File's finalizer doesn't close `fd`. + newfd, err := unix.Dup(fd) + if err != nil { + return nil, err + } + file := os.NewFile(uintptr(newfd), "") + if file == nil { + return nil, errors.New("failed to open TUN file descriptor") + } + return file, nil +} diff --git a/Android/tun2socks/intra/android/tun2socks.go b/Android/tun2socks/intra/android/tun2socks.go new file mode 100644 index 00000000..f709e279 --- /dev/null +++ b/Android/tun2socks/intra/android/tun2socks.go @@ -0,0 +1,110 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun2socks + +import ( + "errors" + "io" + "io/fs" + "log" + "os" + "strings" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/protect" + "github.com/Jigsaw-Code/outline-sdk/network" +) + +// ConnectIntraTunnel reads packets from a TUN device and applies the Intra routing +// rules. Currently, this only consists of redirecting DNS packets to a specified +// server; all other data flows directly to its destination. +// +// `fd` is the TUN device. The IntraTunnel acquires an additional reference to it, which +// +// is released by IntraTunnel.Disconnect(), so the caller must close `fd` _and_ call +// Disconnect() in order to close the TUN device. +// +// `fakedns` is the DNS server that the system believes it is using, in "host:port" style. +// +// The port is normally 53. +// +// `udpdns` and `tcpdns` are the location of the actual DNS server being used. For DNS +// +// tunneling in Intra, these are typically high-numbered ports on localhost. +// +// `dohdns` is the initial DoH transport. It must not be `nil`. +// `protector` is a wrapper for Android's VpnService.protect() method. +// `eventListener` will be provided with a summary of each TCP and UDP socket when it is closed. +// +// Throws an exception if the TUN file descriptor cannot be opened, or if the tunnel fails to +// connect. +func ConnectIntraTunnel( + fd int, fakedns string, dohdns doh.Transport, protector protect.Protector, eventListener intra.Listener, +) (*intra.Tunnel, error) { + tun, err := makeTunFile(fd) + if err != nil { + return nil, err + } + t, err := intra.NewTunnel(fakedns, dohdns, tun, protector, eventListener) + if err != nil { + return nil, err + } + go copyUntilEOF(t, tun) + go copyUntilEOF(tun, t) + return t, nil +} + +// NewDoHTransport returns a DNSTransport that connects to the specified DoH server. +// `url` is the URL of a DoH server (no template, POST-only). If it is nonempty, it +// +// overrides `udpdns` and `tcpdns`. +// +// `ips` is an optional comma-separated list of IP addresses for the server. (This +// +// wrapper is required because gomobile can't make bindings for []string.) +// +// `protector` is the socket protector to use for all external network activity. +// `auth` will provide a client certificate if required by the TLS server. +// `eventListener` will be notified after each DNS query succeeds or fails. +func NewDoHTransport( + url string, ips string, protector protect.Protector, auth doh.ClientAuth, eventListener intra.Listener, +) (doh.Transport, error) { + split := []string{} + if len(ips) > 0 { + split = strings.Split(ips, ",") + } + dialer := protect.MakeDialer(protector) + return doh.NewTransport(url, split, dialer, auth, eventListener) +} + +func copyUntilEOF(dst, src io.ReadWriteCloser) { + log.Printf("[debug] start relaying traffic [%s] -> [%s]", src, dst) + defer log.Printf("[debug] stop relaying traffic [%s] -> [%s]", src, dst) + + const commonMTU = 1500 + buf := make([]byte, commonMTU) + defer dst.Close() + for { + _, err := io.CopyBuffer(dst, src, buf) + if err == nil || isErrClosed(err) { + return + } + } +} + +func isErrClosed(err error) bool { + return errors.Is(err, os.ErrClosed) || errors.Is(err, fs.ErrClosed) || errors.Is(err, network.ErrClosed) +} diff --git a/Android/tun2socks/intra/doh/atomic.go b/Android/tun2socks/intra/doh/atomic.go new file mode 100644 index 00000000..c0a179c9 --- /dev/null +++ b/Android/tun2socks/intra/doh/atomic.go @@ -0,0 +1,38 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "sync/atomic" +) + +// Atomic is atomic.Value, specialized for doh.Transport. +type Atomic struct { + v atomic.Value +} + +// Store a DNSTransport. d must not be nil. +func (a *Atomic) Store(t Transport) { + a.v.Store(t) +} + +// Load the DNSTransport, or nil if it has not been stored. +func (a *Atomic) Load() Transport { + v := a.v.Load() + if v == nil { + return nil + } + return v.(Transport) +} diff --git a/Android/tun2socks/intra/doh/client_auth.go b/Android/tun2socks/intra/doh/client_auth.go new file mode 100644 index 00000000..e7e5aa97 --- /dev/null +++ b/Android/tun2socks/intra/doh/client_auth.go @@ -0,0 +1,116 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "errors" + "io" + + "github.com/eycorsican/go-tun2socks/common/log" +) + +// ClientAuth interface for providing TLS certificates and signatures. +type ClientAuth interface { + // GetClientCertificate returns the client certificate (if any). + // May block as the first call may cause certificates to load. + // Returns a DER encoded X.509 client certificate. + GetClientCertificate() []byte + // GetIntermediateCertificate returns the chaining certificate (if any). + // It does not block or cause certificates to load. + // Returns a DER encoded X.509 certificate. + GetIntermediateCertificate() []byte + // Request a signature on a digest. + Sign(digest []byte) []byte +} + +// clientAuthWrapper manages certificate loading and usage during TLS handshakes. +// Implements crypto.Signer. +type clientAuthWrapper struct { + signer ClientAuth +} + +// GetClientCertificate returns the client certificate chain as a tls.Certificate. +// Returns an empty Certificate on failure, permitting the handshake to +// continue without authentication. +// Implements tls.Config GetClientCertificate(). +func (ca *clientAuthWrapper) GetClientCertificate( + info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + if ca.signer == nil { + log.Warnf("Client certificate requested but not supported") + return &tls.Certificate{}, nil + } + cert := ca.signer.GetClientCertificate() + if cert == nil { + log.Warnf("Unable to fetch client certificate") + return &tls.Certificate{}, nil + } + chain := [][]byte{cert} + intermediate := ca.signer.GetIntermediateCertificate() + if intermediate != nil { + chain = append(chain, intermediate) + } + leaf, err := x509.ParseCertificate(cert) + if err != nil { + log.Warnf("Unable to parse client certificate: %v", err) + return &tls.Certificate{}, nil + } + _, isECDSA := leaf.PublicKey.(*ecdsa.PublicKey) + if !isECDSA { + // RSA-PSS and RSA-SSA both need explicit signature generation support. + log.Warnf("Only ECDSA client certificates are supported") + return &tls.Certificate{}, nil + } + return &tls.Certificate{ + Certificate: chain, + PrivateKey: ca, + Leaf: leaf, + }, nil +} + +// Public returns the public key for the client certificate. +func (ca *clientAuthWrapper) Public() crypto.PublicKey { + if ca.signer == nil { + return nil + } + cert := ca.signer.GetClientCertificate() + leaf, err := x509.ParseCertificate(cert) + if err != nil { + log.Warnf("Unable to parse client certificate: %v", err) + return nil + } + return leaf.PublicKey +} + +// Sign a digest. +func (ca *clientAuthWrapper) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + if ca.signer == nil { + return nil, errors.New("no client certificate") + } + signature := ca.signer.Sign(digest) + if signature == nil { + return nil, errors.New("failed to create signature") + } + return signature, nil +} + +func newClientAuthWrapper(signer ClientAuth) clientAuthWrapper { + return clientAuthWrapper{ + signer: signer, + } +} diff --git a/Android/tun2socks/intra/doh/client_auth_test.go b/Android/tun2socks/intra/doh/client_auth_test.go new file mode 100644 index 00000000..9dfaefd3 --- /dev/null +++ b/Android/tun2socks/intra/doh/client_auth_test.go @@ -0,0 +1,338 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "testing" +) + +// PEM encoded test leaf certificate with ECDSA public key. +var ecCertificate string = `-----BEGIN CERTIFICATE----- +MIIBpTCCAQ4CAiAAMA0GCSqGSIb3DQEBCwUAMD4xCzAJBgNVBAYTAlVTMQswCQYD +VQQIDAJDQTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEKMAgGA1UECgwBWDAeFw0y +MDExMDQwNTU2MTZaFw0zMDExMDIwNTU2MTZaMD4xCzAJBgNVBAYTAlVTMQswCQYD +VQQIDAJDQTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEKMAgGA1UECgwBWDBZMBMG +ByqGSM49AgEGCCqGSM49AwEHA0IABNFVWlOs0tnaLgiutLbPISCd5Fn9UJz6oDen +prTOrHz11PiO/XiqwpJY8yO72QappL/7RYV+uw9hJfU+YOE3tZQwDQYJKoZIhvcN +AQELBQADgYEAdy6CNPvIA7DrS6WrN7N4ZjHjeUtjj2w8n5abTHhvANEvIHI0DARI +AoJJWp4Pe41mzFhROzo+U/ofC2b+ukA8sYqoio4QUxlSW3HkzUAR4HZMi8Risvo3 +OxSR9Lw/mGvZrJ8xr070EwnsD+cCZLfYQ0mSKDM9uPfI3YrgCVKyUwE= +-----END CERTIFICATE-----` + +// PKCS8 encoded test ECDSA private key. +var ecKey string = `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIlI6NB+skAYL36XP +JvE+x5Nlbn0wvw2hlSqIqADiZhShRANCAATRVVpTrNLZ2i4IrrS2zyEgneRZ/VCc ++qA3p6a0zqx89dT4jv14qsKSWPMju9kGqaS/+0WFfrsPYSX1PmDhN7WU +-----END PRIVATE KEY-----` + +// PEM encoded test leaf certificate with RSA public key. +// Doubles as an intermediate depending on the test. +var rsaCertificate string = `-----BEGIN CERTIFICATE----- +MIICWDCCAcGgAwIBAgIUS36guwZMKNO0ADReGLi0cZq8fOowDQYJKoZIhvcNAQEL +BQAwPjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1Nb3VudGFp +biBWaWV3MQowCAYDVQQKDAFYMB4XDTIwMTEwNDA1NDgyNVoXDTMwMTEwMjA1NDgy +NVowPjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1Nb3VudGFp +biBWaWV3MQowCAYDVQQKDAFYMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDd +eznqVu1Rn0m8KR4mX/qVv6uytzZ+juqW5VD55D+w9N6JryPpFHPi4VIm8PKLXp3X +GvY9mc8r+0Ow1qJZYoc/X0Na1c79bv9xwbD3aK28FlAs1+cmyesaFhCWa0bYAvcy +mqQGYhObEWb46E5AANV82CitDE9C1aXRT4SvkLnc6wIDAQABo1MwUTAdBgNVHQ4E +FgQUnUib8BhOHqjq9+gqPQ+ePyEW9zwwHwYDVR0jBBgwFoAUnUib8BhOHqjq9+gq +PQ+ePyEW9zwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQAx/uZG +Gmb5w/u4UkdH7wnoOUNx6GwdraqtQWnFaXb87PmuVAjBwSAnzes2mlp/Vbcd6tYs +pPuHrxOcWgw/aRV6rK3vJZIH3DGvy1pNphGgegEcG88nrUCDcQqPLxvPJ8bmbaee +Tf+l5U2OHC3Yifb4FDOv47kGmq5VeWiYdp60/A== +-----END CERTIFICATE-----` + +// PKCS8 encoded test RSA private key. +var rsaKey string = `-----BEGIN PRIVATE KEY----- +MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBAN17OepW7VGfSbwp +HiZf+pW/q7K3Nn6O6pblUPnkP7D03omvI+kUc+LhUibw8otendca9j2Zzyv7Q7DW +ollihz9fQ1rVzv1u/3HBsPdorbwWUCzX5ybJ6xoWEJZrRtgC9zKapAZiE5sRZvjo +TkAA1XzYKK0MT0LVpdFPhK+QudzrAgMBAAECgYEAoCdhI8Ej7qe+S993u8wfiXWG +FL9DGpUBsYe03F5eZ/lJikopL3voqKDCJQKKgJk0jb0jXjwAgQ86TX+G+hezL5jp +xOOfMmTYgMwnUuFYN1gHAd+TnYB9G1qSQr9TOw3K9Rf4q2x09GhLP75qdr+qzmIR +YGle5ZSP0LqKNkpGNUECQQD+6CxOO8+knnzIFvqkUyNDVFR5ALRNpb53TGVITNf3 +ysT32oJ75ButA0l4q/jsL+MeLLvrHkJOHN+ydLaZOUkbAkEA3m5cICisW9lsT+Rj +glXykkbj3Ougldy7rhPivAaS7clk8cl8cDcIvHna1mDlhSanUu/s4TFEXBLnSzee +XLNIcQJBAJ0n3TD6lSEkCUB/UlX/X81B77aOZZs9pXj9o6/4mGoQHHHGyQ3C7AE1 +9pUsSZKsT3UqFU124WAxUwU+CdnbxKMCQB/QrUC0UKL6oHF0+37DCGU/2ovY8Ck/ +X2Dw2zeFwTJd4iBrb28lkAxVaaXMSkgXVUuZoco8H8kDsy2hEPe1dSECQQCPw5Yg +2gdmdpUk+QetqqhSuuIDwILHU9m3CoX3rY+njaR5LOWDz3utC9Ogo+4wdIMamP/o +2SAWPAZPqDUbtqGH +-----END PRIVATE KEY-----` + +// fakeClientAuth implements the ClientAuth interface for testing. +type fakeClientAuth struct { + certificate *x509.Certificate + intermediate *x509.Certificate + key crypto.PrivateKey +} + +func (ca *fakeClientAuth) GetClientCertificate() []byte { + if ca.certificate == nil { + // Interface uses nil for errors to support binding. + return nil + } + return ca.certificate.Raw +} + +func (ca *fakeClientAuth) GetIntermediateCertificate() []byte { + if ca.intermediate == nil { + return nil + } + return ca.intermediate.Raw +} + +func (ca *fakeClientAuth) Sign(digest []byte) []byte { + if ca.key == nil { + return nil + } + if k, isECDSA := ca.key.(*ecdsa.PrivateKey); isECDSA { + signature, err := ecdsa.SignASN1(rand.Reader, k, digest) + if err != nil { + return nil + } + return signature + } + // Unsupported key type + return nil +} + +func newFakeClientAuth(certificate, intermediate, key []byte) (*fakeClientAuth, error) { + ca := &fakeClientAuth{} + if certificate != nil { + certX509, err := x509.ParseCertificate(certificate) + if err != nil { + return nil, fmt.Errorf("certificate: %v", err) + } + ca.certificate = certX509 + } + if intermediate != nil { + intX509, err := x509.ParseCertificate(intermediate) + if err != nil { + return nil, fmt.Errorf("intermediate: %v", err) + } + ca.intermediate = intX509 + } + if key != nil { + key, err := x509.ParsePKCS8PrivateKey(key) + if err != nil { + return nil, fmt.Errorf("private key: %v", err) + } + ca.key = key + } + return ca, nil +} + +func newCertificateRequestInfo() *tls.CertificateRequestInfo { + return &tls.CertificateRequestInfo{ + Version: tls.VersionTLS13, + } +} + +func newToBeSigned(message []byte) ([]byte, crypto.SignerOpts) { + digest := sha256.Sum256(message) + opts := crypto.SignerOpts(crypto.SHA256) + return digest[:], opts +} + +// Simulate a TLS handshake that requires a client cert and signature. +func TestSign(t *testing.T) { + certDer, _ := pem.Decode([]byte(ecCertificate)) + keyDer, _ := pem.Decode([]byte(ecKey)) + intDer, _ := pem.Decode([]byte(rsaCertificate)) + ca, err := newFakeClientAuth(certDer.Bytes, intDer.Bytes, keyDer.Bytes) + if err != nil { + t.Fatal(err) + } + wrapper := newClientAuthWrapper(ca) + // TLS stack requests the client cert. + req := newCertificateRequestInfo() + cert, err := wrapper.GetClientCertificate(req) + if err != nil { + t.Fatal("Expected to get a client certificate") + } + if cert == nil { + // From the crypto.tls docs: + // If GetClientCertificate returns an error, the handshake will + // be aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. + t.Error("GetClientCertificate must return a non-nil certificate") + } + if len(cert.Certificate) != 2 { + t.Fatal("Certificate chain is the wrong length") + } + if !bytes.Equal(cert.Certificate[0], certDer.Bytes) { + t.Error("Problem with certificate chain[0]") + } + if !bytes.Equal(cert.Certificate[1], intDer.Bytes) { + t.Error("Problem with certificate chain[1]") + } + // TLS stack requests a signature. + digest, opts := newToBeSigned([]byte("hello world")) + signature, err := wrapper.Sign(rand.Reader, digest, opts) + if err != nil { + t.Fatal(err) + } + // Verify the signature. + pub, ok := wrapper.Public().(*ecdsa.PublicKey) + if !ok { + t.Fatal("Expected public key to be ECDSA") + } + if !ecdsa.VerifyASN1(pub, digest, signature) { + t.Fatal("Problem verifying signature") + } +} + +// Simulate a client that does not use an intermediate certificate. +func TestSignNoIntermediate(t *testing.T) { + certDer, _ := pem.Decode([]byte(ecCertificate)) + keyDer, _ := pem.Decode([]byte(ecKey)) + ca, err := newFakeClientAuth(certDer.Bytes, nil, keyDer.Bytes) + if err != nil { + t.Fatal(err) + } + wrapper := newClientAuthWrapper(ca) + // TLS stack requests a client cert. + req := newCertificateRequestInfo() + cert, err := wrapper.GetClientCertificate(req) + if err != nil { + t.Error("Expected to get a client certificate") + } + if cert == nil { + t.Error("GetClientCertificate must return a non-nil certificate") + } + if len(cert.Certificate) != 1 { + t.Error("Certificate chain is the wrong length") + } + if !bytes.Equal(cert.Certificate[0], certDer.Bytes) { + t.Error("Problem with certificate chain[0]") + } + // TLS stack requests a signature + digest, opts := newToBeSigned([]byte("hello world")) + signature, err := wrapper.Sign(rand.Reader, digest, opts) + if err != nil { + t.Error(err) + } + // Verify the signature. + pub, ok := wrapper.Public().(*ecdsa.PublicKey) + if !ok { + t.Error("Expected public key to be ECDSA") + } + if !ecdsa.VerifyASN1(pub, digest, signature) { + t.Error("Problem verifying signature") + } +} + +// Simulate a client that does not have a certificate. +func TestNoAuth(t *testing.T) { + ca, err := newFakeClientAuth(nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wrapper := newClientAuthWrapper(ca) + // TLS stack requests a client cert. + req := newCertificateRequestInfo() + cert, err := wrapper.GetClientCertificate(req) + if err != nil { + t.Error("Expected to get a client certificate") + } + if cert == nil { + t.Error("GetClientCertificate must return a non-nil certificate") + } + if len(cert.Certificate) != 0 { + t.Error("Certificate chain is the wrong length") + } + // TLS stack requests a signature. This should not happen in real life + // because cert.Certificate is empty. + public := wrapper.Public() + if public != nil { + t.Error("Expected public to be nil") + } + digest, opts := newToBeSigned([]byte("hello world")) + _, err = wrapper.Sign(rand.Reader, digest, opts) + if err == nil { + t.Error("Expected Sign() to fail") + } +} + +// Simulate a client that has an RSA certificate. +func TestRSACertificate(t *testing.T) { + certDer, _ := pem.Decode([]byte(rsaCertificate)) + keyDer, _ := pem.Decode([]byte(rsaKey)) + ca, err := newFakeClientAuth(certDer.Bytes, nil, keyDer.Bytes) + if err != nil { + t.Fatal(err) + } + wrapper := newClientAuthWrapper(ca) + // TLS stack requests a client cert. We should not return one because + // we don't support RSA. + req := newCertificateRequestInfo() + cert, err := wrapper.GetClientCertificate(req) + if err != nil { + t.Error("Expected to get a client certificate") + } + if cert == nil { + t.Error("GetClientCertificate must return a non-nil certificate") + } + if len(cert.Certificate) != 0 { + t.Error("Unexpectedly loaded an RSA certificate") + } + // TLS stack requests a signature. This should not happen in real life + // because cert.Certificate is empty. + digest, opts := newToBeSigned([]byte("hello world")) + _, err = wrapper.Sign(rand.Reader, digest, opts) + if err == nil { + t.Error("Expected Sign() to fail") + } +} + +// Simulate a nil loader. +func TestNilLoader(t *testing.T) { + wrapper := newClientAuthWrapper(nil) + // TLS stack requests the client cert. + req := newCertificateRequestInfo() + cert, err := wrapper.GetClientCertificate(req) + if err != nil { + t.Fatal(err) + } + if cert == nil { + // From the crypto.tls docs: + // If GetClientCertificate returns an error, the handshake will + // be aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. + t.Error("GetClientCertificate must return a non-nil certificate") + } + if len(cert.Certificate) != 0 { + t.Fatal("Expected an empty certificate chain") + } + // TLS stack requests a signature. This should not happen in real life + // because cert.Certificate is empty. + digest, opts := newToBeSigned([]byte("hello world")) + _, err = wrapper.Sign(rand.Reader, digest, opts) + if err == nil { + t.Error("Expected Sign() to fail") + } +} diff --git a/Android/tun2socks/intra/doh/doh.go b/Android/tun2socks/intra/doh/doh.go new file mode 100644 index 00000000..082ce656 --- /dev/null +++ b/Android/tun2socks/intra/doh/doh.go @@ -0,0 +1,566 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "bytes" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "strconv" + "sync" + "time" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh/ipmap" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/split" + "github.com/eycorsican/go-tun2socks/common/log" + "golang.org/x/net/dns/dnsmessage" +) + +const ( + // Complete : Transaction completed successfully + Complete = iota + // SendFailed : Failed to send query + SendFailed + // HTTPError : Got a non-200 HTTP status + HTTPError + // BadQuery : Malformed input + BadQuery + // BadResponse : Response was invalid + BadResponse + // InternalError : This should never happen + InternalError +) + +// If the server sends an invalid reply, we start a "servfail hangover" +// of this duration, during which all queries are rejected. +// This rate-limits queries to misconfigured servers (e.g. wrong URL). +const hangoverDuration = 10 * time.Second + +// Summary is a summary of a DNS transaction, reported when it is complete. +type Summary struct { + Latency float64 // Response (or failure) latency in seconds + Query []byte + Response []byte + Server string + Status int + HTTPStatus int // Zero unless Status is Complete or HTTPError +} + +// A Token is an opaque handle used to match responses to queries. +type Token interface{} + +// Listener receives Summaries. +type Listener interface { + OnQuery(url string) Token + OnResponse(Token, *Summary) +} + +// Transport represents a DNS query transport. This interface is exported by gobind, +// so it has to be very simple. +type Transport interface { + // Given a DNS query (including ID), returns a DNS response with matching + // ID, or an error if no response was received. The error may be accompanied + // by a SERVFAIL response if appropriate. + Query(q []byte) ([]byte, error) + // Return the server URL used to initialize this transport. + GetURL() string +} + +// TODO: Keep a context here so that queries can be canceled. +type transport struct { + Transport + url string + hostname string + port int + ips ipmap.IPMap + client http.Client + dialer *net.Dialer + listener Listener + hangoverLock sync.RWMutex + hangoverExpiration time.Time +} + +// Wait up to three seconds for the TCP handshake to complete. +const tcpTimeout time.Duration = 3 * time.Second + +func (t *transport) dial(network, addr string) (net.Conn, error) { + log.Debugf("Dialing %s", addr) + domain, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + tcpaddr := func(ip net.IP) *net.TCPAddr { + return &net.TCPAddr{IP: ip, Port: port} + } + + // TODO: Improve IP fallback strategy with parallelism and Happy Eyeballs. + var conn net.Conn + ips := t.ips.Get(domain) + confirmed := ips.Confirmed() + if confirmed != nil { + log.Debugf("Trying confirmed IP %s for addr %s", confirmed.String(), addr) + if conn, err = split.DialWithSplitRetry(t.dialer, tcpaddr(confirmed), nil); err == nil { + log.Infof("Confirmed IP %s worked", confirmed.String()) + return conn, nil + } + log.Debugf("Confirmed IP %s failed with err %v", confirmed.String(), err) + ips.Disconfirm(confirmed) + } + + log.Debugf("Trying all IPs") + for _, ip := range ips.GetAll() { + if ip.Equal(confirmed) { + // Don't try this IP twice. + continue + } + if conn, err = split.DialWithSplitRetry(t.dialer, tcpaddr(ip), nil); err == nil { + log.Infof("Found working IP: %s", ip.String()) + return conn, nil + } + } + return nil, err +} + +// NewTransport returns a DoH DNSTransport, ready for use. +// This is a POST-only DoH implementation, so the DoH template should be a URL. +// `rawurl` is the DoH template in string form. +// `addrs` is a list of domains or IP addresses to use as fallback, if the hostname +// +// lookup fails or returns non-working addresses. +// +// `dialer` is the dialer that the transport will use. The transport will modify the dialer's +// +// timeout but will not mutate it otherwise. +// +// `auth` will provide a client certificate if required by the TLS server. +// `listener` will receive the status of each DNS query when it is complete. +func NewTransport(rawurl string, addrs []string, dialer *net.Dialer, auth ClientAuth, listener Listener) (Transport, error) { + if dialer == nil { + dialer = &net.Dialer{} + } + parsedurl, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + if parsedurl.Scheme != "https" { + return nil, fmt.Errorf("Bad scheme: %s", parsedurl.Scheme) + } + // Resolve the hostname and put those addresses first. + portStr := parsedurl.Port() + var port int + if len(portStr) > 0 { + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, err + } + } else { + port = 443 + } + + t := &transport{ + url: rawurl, + hostname: parsedurl.Hostname(), + port: port, + listener: listener, + dialer: dialer, + ips: ipmap.NewIPMap(dialer.Resolver), + } + ips := t.ips.Get(t.hostname) + for _, addr := range addrs { + ips.Add(addr) + } + if ips.Empty() { + return nil, fmt.Errorf("No IP addresses for %s", t.hostname) + } + + // Supply a client certificate during TLS handshakes. + var tlsconfig *tls.Config + if auth != nil { + signer := newClientAuthWrapper(auth) + tlsconfig = &tls.Config{ + GetClientCertificate: signer.GetClientCertificate, + } + } + + // Override the dial function. + t.client.Transport = &http.Transport{ + Dial: t.dial, + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 20 * time.Second, // Same value as Android DNS-over-TLS + TLSClientConfig: tlsconfig, + } + return t, nil +} + +type queryError struct { + status int + err error +} + +func (e *queryError) Error() string { + return e.err.Error() +} + +func (e *queryError) Unwrap() error { + return e.err +} + +type httpError struct { + status int +} + +func (e *httpError) Error() string { + return fmt.Sprintf("HTTP request failed: %d", e.status) +} + +// Given a raw DNS query (including the query ID), this function sends the +// query. If the query is successful, it returns the response and a nil qerr. Otherwise, +// it returns a SERVFAIL response and a qerr with a status value indicating the cause. +// Independent of the query's success or failure, this function also returns the +// address of the server on a best-effort basis, or nil if the address could not +// be determined. +func (t *transport) doQuery(q []byte) (response []byte, server *net.TCPAddr, qerr *queryError) { + if len(q) < 2 { + qerr = &queryError{BadQuery, fmt.Errorf("Query length is %d", len(q))} + return + } + + t.hangoverLock.RLock() + inHangover := time.Now().Before(t.hangoverExpiration) + t.hangoverLock.RUnlock() + if inHangover { + response = tryServfail(q) + qerr = &queryError{HTTPError, errors.New("Forwarder is in servfail hangover")} + return + } + + // Add padding to the raw query + q, err := AddEdnsPadding(q) + if err != nil { + qerr = &queryError{InternalError, err} + return + } + + // Zero out the query ID. + id := binary.BigEndian.Uint16(q) + binary.BigEndian.PutUint16(q, 0) + req, err := http.NewRequest(http.MethodPost, t.url, bytes.NewBuffer(q)) + if err != nil { + qerr = &queryError{InternalError, err} + return + } + + var hostname string + response, hostname, server, qerr = t.sendRequest(id, req) + + // Restore the query ID. + binary.BigEndian.PutUint16(q, id) + if qerr == nil { + if len(response) >= 2 { + if binary.BigEndian.Uint16(response) == 0 { + binary.BigEndian.PutUint16(response, id) + } else { + qerr = &queryError{BadResponse, errors.New("Nonzero response ID")} + } + } else { + qerr = &queryError{BadResponse, fmt.Errorf("Response length is %d", len(response))} + } + } + + if qerr != nil { + if qerr.status != SendFailed { + t.hangoverLock.Lock() + t.hangoverExpiration = time.Now().Add(hangoverDuration) + t.hangoverLock.Unlock() + } + + response = tryServfail(q) + } else if server != nil { + // Record a working IP address for this server iff qerr is nil + t.ips.Get(hostname).Confirm(server.IP) + } + return +} + +func (t *transport) sendRequest(id uint16, req *http.Request) (response []byte, hostname string, server *net.TCPAddr, qerr *queryError) { + hostname = t.hostname + + // The connection used for this request. If the request fails, we will close + // this socket, in case it is no longer functioning. + var conn net.Conn + + // Error cleanup function. If the query fails, this function will close the + // underlying socket and disconfirm the server IP. Empirically, sockets often + // become unresponsive after a network change, causing timeouts on all requests. + defer func() { + if qerr == nil { + return + } + log.Infof("%d Query failed: %v", id, qerr) + if server != nil { + log.Debugf("%d Disconfirming %s", id, server.IP.String()) + t.ips.Get(hostname).Disconfirm(server.IP) + } + if conn != nil { + log.Infof("%d Closing failing DoH socket", id) + conn.Close() + } + }() + + // Add a trace to the request in order to expose the server's IP address. + // Only GotConn performs any action; the other methods just provide debug logs. + // GotConn runs before client.Do() returns, so there is no data race when + // reading the variables it has set. + trace := httptrace.ClientTrace{ + GetConn: func(hostPort string) { + log.Debugf("%d GetConn(%s)", id, hostPort) + }, + GotConn: func(info httptrace.GotConnInfo) { + log.Debugf("%d GotConn(%v)", id, info) + if info.Conn == nil { + return + } + conn = info.Conn + // info.Conn is a DuplexConn, so RemoteAddr is actually a TCPAddr. + server = conn.RemoteAddr().(*net.TCPAddr) + }, + PutIdleConn: func(err error) { + log.Debugf("%d PutIdleConn(%v)", id, err) + }, + GotFirstResponseByte: func() { + log.Debugf("%d GotFirstResponseByte()", id) + }, + Got100Continue: func() { + log.Debugf("%d Got100Continue()", id) + }, + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + log.Debugf("%d Got1xxResponse(%d, %v)", id, code, header) + return nil + }, + DNSStart: func(info httptrace.DNSStartInfo) { + log.Debugf("%d DNSStart(%v)", id, info) + }, + DNSDone: func(info httptrace.DNSDoneInfo) { + log.Debugf("%d, DNSDone(%v)", id, info) + }, + ConnectStart: func(network, addr string) { + log.Debugf("%d ConnectStart(%s, %s)", id, network, addr) + }, + ConnectDone: func(network, addr string, err error) { + log.Debugf("%d ConnectDone(%s, %s, %v)", id, network, addr, err) + }, + TLSHandshakeStart: func() { + log.Debugf("%d TLSHandshakeStart()", id) + }, + TLSHandshakeDone: func(state tls.ConnectionState, err error) { + log.Debugf("%d TLSHandshakeDone(%v, %v)", id, state, err) + }, + WroteHeaders: func() { + log.Debugf("%d WroteHeaders()", id) + }, + WroteRequest: func(info httptrace.WroteRequestInfo) { + log.Debugf("%d WroteRequest(%v)", id, info) + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), &trace)) + + const mimetype = "application/dns-message" + req.Header.Set("Content-Type", mimetype) + req.Header.Set("Accept", mimetype) + req.Header.Set("User-Agent", "Intra") + log.Debugf("%d Sending query", id) + httpResponse, err := t.client.Do(req) + if err != nil { + qerr = &queryError{SendFailed, err} + return + } + log.Debugf("%d Got response", id) + response, err = ioutil.ReadAll(httpResponse.Body) + if err != nil { + qerr = &queryError{BadResponse, err} + return + } + httpResponse.Body.Close() + log.Debugf("%d Closed response", id) + + // Update the hostname, which could have changed due to a redirect. + hostname = httpResponse.Request.URL.Hostname() + + if httpResponse.StatusCode != http.StatusOK { + reqBuf := new(bytes.Buffer) + req.Write(reqBuf) + respBuf := new(bytes.Buffer) + httpResponse.Write(respBuf) + log.Debugf("%d request: %s\nresponse: %s", id, reqBuf.String(), respBuf.String()) + + qerr = &queryError{HTTPError, &httpError{httpResponse.StatusCode}} + return + } + + return +} + +func (t *transport) Query(q []byte) ([]byte, error) { + var token Token + if t.listener != nil { + token = t.listener.OnQuery(t.url) + } + + before := time.Now() + response, server, qerr := t.doQuery(q) + after := time.Now() + + var err error + status := Complete + httpStatus := http.StatusOK + if qerr != nil { + err = qerr + status = qerr.status + httpStatus = 0 + + var herr *httpError + if errors.As(qerr.err, &herr) { + httpStatus = herr.status + } + } + + if t.listener != nil { + latency := after.Sub(before) + var ip string + if server != nil { + ip = server.IP.String() + } + + t.listener.OnResponse(token, &Summary{ + Latency: latency.Seconds(), + Query: q, + Response: response, + Server: ip, + Status: status, + HTTPStatus: httpStatus, + }) + } + return response, err +} + +func (t *transport) GetURL() string { + return t.url +} + +// Perform a query using the transport, and send the response to the writer. +func forwardQuery(t Transport, q []byte, c io.Writer) error { + resp, qerr := t.Query(q) + if resp == nil && qerr != nil { + return qerr + } + rlen := len(resp) + if rlen > math.MaxUint16 { + return fmt.Errorf("Oversize response: %d", rlen) + } + // Use a combined write to ensure atomicity. Otherwise, writes from two + // responses could be interleaved. + rlbuf := make([]byte, rlen+2) + binary.BigEndian.PutUint16(rlbuf, uint16(rlen)) + copy(rlbuf[2:], resp) + n, err := c.Write(rlbuf) + if err != nil { + return err + } + if int(n) != len(rlbuf) { + return fmt.Errorf("Incomplete response write: %d < %d", n, len(rlbuf)) + } + return qerr +} + +// Perform a query using the transport, send the response to the writer, +// and close the writer if there was an error. +func forwardQueryAndCheck(t Transport, q []byte, c io.WriteCloser) { + if err := forwardQuery(t, q, c); err != nil { + log.Warnf("Query forwarding failed: %v", err) + c.Close() + } +} + +// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket +// to this DNSTransport. +func Accept(t Transport, c io.ReadWriteCloser) { + qlbuf := make([]byte, 2) + for { + n, err := c.Read(qlbuf) + if n == 0 { + log.Debugf("TCP query socket clean shutdown") + break + } + if err != nil { + log.Warnf("Error reading from TCP query socket: %v", err) + break + } + if n < 2 { + log.Warnf("Incomplete query length") + break + } + qlen := binary.BigEndian.Uint16(qlbuf) + q := make([]byte, qlen) + n, err = c.Read(q) + if err != nil { + log.Warnf("Error reading query: %v", err) + break + } + if n != int(qlen) { + log.Warnf("Incomplete query: %d < %d", n, qlen) + break + } + go forwardQueryAndCheck(t, q, c) + } + // TODO: Cancel outstanding queries at this point. + c.Close() +} + +// Servfail returns a SERVFAIL response to the query q. +func Servfail(q []byte) ([]byte, error) { + var msg dnsmessage.Message + if err := msg.Unpack(q); err != nil { + return nil, err + } + msg.Response = true + msg.RecursionAvailable = true + msg.RCode = dnsmessage.RCodeServerFailure + msg.Additionals = nil // Strip EDNS + return msg.Pack() +} + +func tryServfail(q []byte) []byte { + response, err := Servfail(q) + if err != nil { + log.Warnf("Error constructing servfail: %v", err) + } + return response +} diff --git a/Android/tun2socks/intra/doh/doh_test.go b/Android/tun2socks/intra/doh/doh_test.go new file mode 100644 index 00000000..7ad0797a --- /dev/null +++ b/Android/tun2socks/intra/doh/doh_test.go @@ -0,0 +1,891 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "reflect" + "testing" + + "golang.org/x/net/dns/dnsmessage" +) + +var testURL = "https://dns.google/dns-query" +var ips = []string{ + "8.8.8.8", + "8.8.4.4", + "2001:4860:4860::8888", + "2001:4860:4860::8844", +} +var parsedURL *url.URL + +var simpleQuery dnsmessage.Message = dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 0xbeef, + Response: false, + OpCode: 0, + Authoritative: false, + Truncated: false, + RecursionDesired: true, + RecursionAvailable: false, + RCode: 0, + }, + Questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("www.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }}, + Answers: []dnsmessage.Resource{}, + Authorities: []dnsmessage.Resource{}, + Additionals: []dnsmessage.Resource{}, +} + +func mustPack(m *dnsmessage.Message) []byte { + packed, err := m.Pack() + if err != nil { + panic(err) + } + return packed +} + +func mustUnpack(q []byte) *dnsmessage.Message { + var m dnsmessage.Message + err := m.Unpack(q) + if err != nil { + panic(err) + } + return &m +} + +var simpleQueryBytes []byte = mustPack(&simpleQuery) + +var compressedQueryBytes []byte = []byte{ + 0xbe, 0xef, // ID + 0x01, // QR, OPCODE, AA, TC, RD + 0x00, // RA, Z, RCODE + 0x00, 0x02, // QDCOUNT = 2 + 0x00, 0x00, // ANCOUNT = 0 + 0x00, 0x00, // NSCOUNT = 0 + 0x00, 0x00, // ARCOUNT = 0 + // Question 1 + 0x03, 'f', 'o', 'o', + 0x03, 'b', 'a', 'r', + 0x00, + 0x00, 0x01, // QTYPE: A query + 0x00, 0x01, // QCLASS: IN + // Question 2 + 0xc0, 12, // Pointer to beginning of "foo.bar." + 0x00, 0x01, // QTYPE: A query + 0x00, 0x01, // QCLASS: IN +} + +var uncompressedQueryBytes []byte = []byte{ + 0xbe, 0xef, // ID + 0x01, // QR, OPCODE, AA, TC, RD + 0x00, // RA, Z, RCODE + 0x00, 0x02, // QDCOUNT = 2 + 0x00, 0x00, // ANCOUNT = 0 + 0x00, 0x00, // NSCOUNT = 0 + 0x00, 0x00, // ARCOUNT = 0 + // Question 1 + 0x03, 'f', 'o', 'o', + 0x03, 'b', 'a', 'r', + 0x00, + 0x00, 0x01, // QTYPE: A query + 0x00, 0x01, // QCLASS: IN + // Question 2 + 0x03, 'f', 'o', 'o', + 0x03, 'b', 'a', 'r', + 0x00, + 0x00, 0x01, // QTYPE: A query + 0x00, 0x01, // QCLASS: IN +} + +func init() { + parsedURL, _ = url.Parse(testURL) +} + +// Check that the constructor works. +func TestNewTransport(t *testing.T) { + _, err := NewTransport(testURL, ips, nil, nil, nil) + if err != nil { + t.Fatal(err) + } +} + +// Check that the constructor rejects unsupported URLs. +func TestBadUrl(t *testing.T) { + _, err := NewTransport("ftp://www.example.com", nil, nil, nil, nil) + if err == nil { + t.Error("Expected error") + } + _, err = NewTransport("https://www.example", nil, nil, nil, nil) + if err == nil { + t.Error("Expected error") + } +} + +// Check for failure when the query is too short to be valid. +func TestShortQuery(t *testing.T) { + var qerr *queryError + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + _, err := doh.Query([]byte{}) + if err == nil { + t.Error("Empty query should fail") + } else if !errors.As(err, &qerr) { + t.Errorf("Wrong error type: %v", err) + } else if qerr.status != BadQuery { + t.Errorf("Wrong error status: %d", qerr.status) + } + + _, err = doh.Query([]byte{1}) + if err == nil { + t.Error("One byte query should fail") + } else if !errors.As(err, &qerr) { + t.Errorf("Wrong error type: %v", err) + } else if qerr.status != BadQuery { + t.Errorf("Wrong error status: %d", qerr.status) + } +} + +// Send a DoH query to an actual DoH server +func TestQueryIntegration(t *testing.T) { + queryData := []byte{ + 111, 222, // [0-1] query ID + 1, 0, // [2-3] flags, RD=1 + 0, 1, // [4-5] QDCOUNT (number of queries) = 1 + 0, 0, // [6-7] ANCOUNT (number of answers) = 0 + 0, 0, // [8-9] NSCOUNT (number of authoritative answers) = 0 + 0, 0, // [10-11] ARCOUNT (number of additional records) = 0 + // Start of first query + 7, 'y', 'o', 'u', 't', 'u', 'b', 'e', + 3, 'c', 'o', 'm', + 0, // null terminator of FQDN (DNS root) + 0, 1, // QTYPE = A + 0, 1, // QCLASS = IN (Internet) + } + + testQuery := func(queryData []byte) { + + doh, err := NewTransport(testURL, ips, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + resp, err2 := doh.Query(queryData) + if err2 != nil { + t.Fatal(err2) + } + if resp[0] != queryData[0] || resp[1] != queryData[1] { + t.Error("Query ID mismatch") + } + if len(resp) <= len(queryData) { + t.Error("Response is short") + } + } + + testQuery(queryData) + + paddedQueryBytes, err := AddEdnsPadding(simpleQueryBytes) + if err != nil { + t.Fatal(err) + } + + testQuery(paddedQueryBytes) +} + +type testRoundTripper struct { + http.RoundTripper + req chan *http.Request + resp chan *http.Response + err error +} + +func makeTestRoundTripper() *testRoundTripper { + return &testRoundTripper{ + req: make(chan *http.Request), + resp: make(chan *http.Response), + } +} + +func (r *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if r.err != nil { + return nil, r.err + } + r.req <- req + return <-r.resp, nil +} + +// Check that a DNS query is converted correctly into an HTTP query. +func TestRequest(t *testing.T) { + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + go doh.Query(simpleQueryBytes) + req := <-rt.req + if req.URL.String() != testURL { + t.Errorf("URL mismatch: %s != %s", req.URL.String(), testURL) + } + reqBody, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Error(err) + } + if len(reqBody)%PaddingBlockSize != 0 { + t.Errorf("reqBody has unexpected length: %d", len(reqBody)) + } + // Parse reqBody into a Message. + newQuery := mustUnpack(reqBody) + // Ensure the converted request has an ID of zero. + if newQuery.Header.ID != 0 { + t.Errorf("Unexpected request header id: %v", newQuery.Header.ID) + } + // Check that all fields except for Header.ID and Additionals + // are the same as the original. Additionals may differ if + // padding was added. + if !queriesMostlyEqual(simpleQuery, *newQuery) { + t.Errorf("Unexpected query body:\n\t%v\nExpected:\n\t%v", newQuery, simpleQuery) + } + contentType := req.Header.Get("Content-Type") + if contentType != "application/dns-message" { + t.Errorf("Wrong content type: %s", contentType) + } + accept := req.Header.Get("Accept") + if accept != "application/dns-message" { + t.Errorf("Wrong Accept header: %s", accept) + } +} + +// Check that all fields of m1 match those of m2, except for Header.ID +// and Additionals. +func queriesMostlyEqual(m1 dnsmessage.Message, m2 dnsmessage.Message) bool { + // Make fields we don't care about match, so that equality check is easy. + m1.Header.ID = m2.Header.ID + m1.Additionals = m2.Additionals + return reflect.DeepEqual(m1, m2) +} + +// Check that a DOH response is returned correctly. +func TestResponse(t *testing.T) { + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + + // Fake server. + go func() { + <-rt.req + r, w := io.Pipe() + rt.resp <- &http.Response{ + StatusCode: 200, + Body: r, + Request: &http.Request{URL: parsedURL}, + } + // The DOH response should have a zero query ID. + var modifiedQuery dnsmessage.Message = simpleQuery + modifiedQuery.Header.ID = 0 + w.Write(mustPack(&modifiedQuery)) + w.Close() + }() + + resp, err := doh.Query(simpleQueryBytes) + if err != nil { + t.Error(err) + } + + // Parse the response as a DNS message. + respParsed := mustUnpack(resp) + + // Query() should reconstitute the query ID in the response. + if respParsed.Header.ID != simpleQuery.Header.ID || + !queriesMostlyEqual(*respParsed, simpleQuery) { + t.Errorf("Unexpected response %v", resp) + } +} + +// Simulate an empty response. (This is not a compliant server +// behavior.) +func TestEmptyResponse(t *testing.T) { + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + + // Fake server. + go func() { + <-rt.req + // Make an empty body. + r, w := io.Pipe() + w.Close() + rt.resp <- &http.Response{ + StatusCode: 200, + Body: r, + Request: &http.Request{URL: parsedURL}, + } + }() + + _, err := doh.Query(simpleQueryBytes) + var qerr *queryError + if err == nil { + t.Error("Empty body should cause an error") + } else if !errors.As(err, &qerr) { + t.Errorf("Wrong error type: %v", err) + } else if qerr.status != BadResponse { + t.Errorf("Wrong error status: %d", qerr.status) + } +} + +// Simulate a non-200 HTTP response code. +func TestHTTPError(t *testing.T) { + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + + go func() { + <-rt.req + r, w := io.Pipe() + rt.resp <- &http.Response{ + StatusCode: 500, + Body: r, + Request: &http.Request{URL: parsedURL}, + } + w.Write([]byte{0, 0, 8, 9, 10}) + w.Close() + }() + + _, err := doh.Query(simpleQueryBytes) + var qerr *queryError + if err == nil { + t.Error("Empty body should cause an error") + } else if !errors.As(err, &qerr) { + t.Errorf("Wrong error type: %v", err) + } else if qerr.status != HTTPError { + t.Errorf("Wrong error status: %d", qerr.status) + } +} + +// Simulate an HTTP query error. +func TestSendFailed(t *testing.T) { + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + + rt.err = errors.New("test") + _, err := doh.Query(simpleQueryBytes) + var qerr *queryError + if err == nil { + t.Error("Send failure should be reported") + } else if !errors.As(err, &qerr) { + t.Errorf("Wrong error type: %v", err) + } else if qerr.status != SendFailed { + t.Errorf("Wrong error status: %d", qerr.status) + } else if !errors.Is(qerr, rt.err) { + t.Errorf("Underlying error is not retained") + } +} + +// Test if DoH resolver IPs are confirmed and disconfirmed +// when queries suceeded and fail, respectively. +func TestDohIPConfirmDisconfirm(t *testing.T) { + u, _ := url.Parse(testURL) + doh, _ := NewTransport(testURL, ips, nil, nil, nil) + transport := doh.(*transport) + hostname := u.Hostname() + ipmap := transport.ips.Get(hostname) + + // send a valid request to first have confirmed-ip set + res, _ := doh.Query(simpleQueryBytes) + mustUnpack(res) + ip1 := ipmap.Confirmed() + + if ip1 == nil { + t.Errorf("IP not confirmed despite valid query to %s", u) + } + + // simulate http-fail with doh server-ip set to previously confirmed-ip + rt := makeTestRoundTripper() + transport.client.Transport = rt + go func() { + req := <-rt.req + trace := httptrace.ContextClientTrace(req.Context()) + trace.GotConn(httptrace.GotConnInfo{ + Conn: &fakeConn{ + remoteAddr: &net.TCPAddr{ + IP: ip1, // confirmed-ip from before + Port: 443, + }}}) + rt.resp <- &http.Response{ + StatusCode: 509, // some non-2xx status + Body: nil, + Request: &http.Request{URL: u}, + } + }() + doh.Query(simpleQueryBytes) + ip2 := ipmap.Confirmed() + + if ip2 != nil { + t.Errorf("IP confirmed (%s) despite err", ip2) + } +} + +type fakeListener struct { + Listener + summary *Summary +} + +func (l *fakeListener) OnQuery(url string) Token { + return nil +} + +func (l *fakeListener) OnResponse(tok Token, summ *Summary) { + l.summary = summ +} + +type fakeConn struct { + net.TCPConn + remoteAddr *net.TCPAddr +} + +func (c *fakeConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// Check that the DNSListener is called with a correct summary. +func TestListener(t *testing.T) { + listener := &fakeListener{} + doh, _ := NewTransport(testURL, ips, nil, nil, listener) + transport := doh.(*transport) + rt := makeTestRoundTripper() + transport.client.Transport = rt + + go func() { + req := <-rt.req + trace := httptrace.ContextClientTrace(req.Context()) + trace.GotConn(httptrace.GotConnInfo{ + Conn: &fakeConn{ + remoteAddr: &net.TCPAddr{ + IP: net.ParseIP("192.0.2.2"), + Port: 443, + }}}) + + r, w := io.Pipe() + rt.resp <- &http.Response{ + StatusCode: 200, + Body: r, + Request: &http.Request{URL: parsedURL}, + } + w.Write([]byte{0, 0, 8, 9, 10}) + w.Close() + }() + + doh.Query(simpleQueryBytes) + s := listener.summary + if s.Latency < 0 { + t.Errorf("Negative latency: %f", s.Latency) + } + if !bytes.Equal(s.Query, simpleQueryBytes) { + t.Errorf("Wrong query: %v", s.Query) + } + if !bytes.Equal(s.Response, []byte{0xbe, 0xef, 8, 9, 10}) { + t.Errorf("Wrong response: %v", s.Response) + } + if s.Server != "192.0.2.2" { + t.Errorf("Wrong server IP string: %s", s.Server) + } + if s.Status != Complete { + t.Errorf("Wrong status: %d", s.Status) + } +} + +type socket struct { + r io.ReadCloser + w io.WriteCloser +} + +func (c *socket) Read(b []byte) (int, error) { + return c.r.Read(b) +} + +func (c *socket) Write(b []byte) (int, error) { + return c.w.Write(b) +} + +func (c *socket) Close() error { + e1 := c.r.Close() + e2 := c.w.Close() + if e1 != nil { + return e1 + } + return e2 +} + +func makePair() (io.ReadWriteCloser, io.ReadWriteCloser) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + return &socket{r1, w2}, &socket{r2, w1} +} + +type fakeTransport struct { + Transport + query chan []byte + response chan []byte + err error +} + +func (t *fakeTransport) Query(q []byte) ([]byte, error) { + t.query <- q + if t.err != nil { + return nil, t.err + } + return <-t.response, nil +} + +func (t *fakeTransport) GetURL() string { + return "fake" +} + +func (t *fakeTransport) Close() { + t.err = errors.New("closed") + close(t.query) + close(t.response) +} + +func newFakeTransport() *fakeTransport { + return &fakeTransport{ + query: make(chan []byte), + response: make(chan []byte), + } +} + +// Test a successful query over TCP +func TestAccept(t *testing.T) { + doh := newFakeTransport() + client, server := makePair() + + // Start the forwarder running. + go Accept(doh, server) + + lbuf := make([]byte, 2) + // Send Query + queryData := simpleQueryBytes + binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) + n, err := client.Write(lbuf) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Error("Length write problem") + } + n, err = client.Write(queryData) + if err != nil { + t.Fatal(err) + } + if n != len(queryData) { + t.Error("Query write problem") + } + + // Read query + queryRead := <-doh.query + if !bytes.Equal(queryRead, queryData) { + t.Error("Query mismatch") + } + + // Send fake response + responseData := []byte{1, 2, 8, 9, 10} + doh.response <- responseData + + // Get Response + n, err = client.Read(lbuf) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Error("Length read problem") + } + rlen := binary.BigEndian.Uint16(lbuf) + resp := make([]byte, int(rlen)) + n, err = client.Read(resp) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(responseData, resp) { + t.Error("Response mismatch") + } + + client.Close() +} + +// Sends a TCP query that results in failure. When a query fails, +// Accept should close the TCP socket. +func TestAcceptFail(t *testing.T) { + doh := newFakeTransport() + client, server := makePair() + + // Start the forwarder running. + go Accept(doh, server) + + lbuf := make([]byte, 2) + // Send Query + queryData := simpleQueryBytes + binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) + client.Write(lbuf) + client.Write(queryData) + + // Indicate that the query failed + doh.err = errors.New("fake error") + + // Read query + queryRead := <-doh.query + if !bytes.Equal(queryRead, queryData) { + t.Error("Query mismatch") + } + + // Accept should have closed the socket. + n, _ := client.Read(lbuf) + if n != 0 { + t.Error("Expected to read 0 bytes") + } +} + +// Sends a TCP query, and closes the socket before the response is sent. +// This tests for crashes when a response cannot be delivered. +func TestAcceptClose(t *testing.T) { + doh := newFakeTransport() + client, server := makePair() + + // Start the forwarder running. + go Accept(doh, server) + + lbuf := make([]byte, 2) + // Send Query + queryData := simpleQueryBytes + binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) + client.Write(lbuf) + client.Write(queryData) + + // Read query + queryRead := <-doh.query + if !bytes.Equal(queryRead, queryData) { + t.Error("Query mismatch") + } + + // Close the TCP connection + client.Close() + + // Send fake response too late. + responseData := []byte{1, 2, 8, 9, 10} + doh.response <- responseData +} + +// Test failure due to a response that is larger than the +// maximum message size for DNS over TCP (65535). +func TestAcceptOversize(t *testing.T) { + doh := newFakeTransport() + client, server := makePair() + + // Start the forwarder running. + go Accept(doh, server) + + lbuf := make([]byte, 2) + // Send Query + queryData := simpleQueryBytes + binary.BigEndian.PutUint16(lbuf, uint16(len(queryData))) + client.Write(lbuf) + client.Write(queryData) + + // Read query + <-doh.query + + // Send oversize response + doh.response <- make([]byte, 65536) + + // Accept should have closed the socket because the response + // cannot be written. + n, _ := client.Read(lbuf) + if n != 0 { + t.Error("Expected to read 0 bytes") + } +} + +func TestComputePaddingSize(t *testing.T) { + if computePaddingSize(100-kOptPaddingHeaderLen, 100) != 0 { + t.Errorf("Expected no padding") + } + if computePaddingSize(200-kOptPaddingHeaderLen, 100) != 0 { + t.Errorf("Expected no padding") + } + if computePaddingSize(190-kOptPaddingHeaderLen, 100) != 10 { + t.Errorf("Expected to pad up to next block") + } +} + +func TestAddEdnsPaddingIdempotent(t *testing.T) { + padded, err := AddEdnsPadding(simpleQueryBytes) + if err != nil { + t.Fatal(err) + } + paddedAgain, err := AddEdnsPadding(padded) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(padded, paddedAgain) { + t.Errorf("Padding should be idempotent\n%v\n%v", padded, paddedAgain) + } +} + +// Check that packing |compressedQueryBytes| constructs the same query +// byte-for-byte. +func TestDnsMessageCompressedQueryConfidenceCheck(t *testing.T) { + m := mustUnpack(compressedQueryBytes) + packedBytes := mustPack(m) + if len(packedBytes) != len(compressedQueryBytes) { + t.Errorf("Packed query has different size than original:\n %v\n %v", packedBytes, compressedQueryBytes) + } +} + +// Check that packing |uncompressedQueryBytes| constructs a smaller +// query byte-for-byte, since label compression is enabled by default. +func TestDnsMessageUncompressedQueryConfidenceCheck(t *testing.T) { + m := mustUnpack(uncompressedQueryBytes) + packedBytes := mustPack(m) + if len(packedBytes) >= len(uncompressedQueryBytes) { + t.Errorf("Compressed query is not smaller than uncompressed query") + } +} + +// Check that we correctly pad an uncompressed query to the nearest block. +func TestAddEdnsPaddingUncompressedQuery(t *testing.T) { + if len(uncompressedQueryBytes)%PaddingBlockSize == 0 { + t.Errorf("uncompressedQueryBytes does not require padding, so this test is invalid") + } + padded, err := AddEdnsPadding(uncompressedQueryBytes) + if err != nil { + panic(err) + } + if len(padded)%PaddingBlockSize != 0 { + t.Errorf("AddEdnsPadding failed to correctly pad uncompressed query") + } +} + +// Check that we correctly pad a compressed query to the nearest block. +func TestAddEdnsPaddingCompressedQuery(t *testing.T) { + if len(compressedQueryBytes)%PaddingBlockSize == 0 { + t.Errorf("compressedQueryBytes does not require padding, so this test is invalid") + } + padded, err := AddEdnsPadding(compressedQueryBytes) + if err != nil { + panic(err) + } + if len(padded)%PaddingBlockSize != 0 { + t.Errorf("AddEdnsPadding failed to correctly pad compressed query") + } +} + +// Try to pad a query that already contains an OPT record, but no padding option. +func TestAddEdnsPaddingCompressedOptQuery(t *testing.T) { + optQuery := simpleQuery + optQuery.Additionals = make([]dnsmessage.Resource, len(simpleQuery.Additionals)) + copy(optQuery.Additionals, simpleQuery.Additionals) + + optQuery.Additionals = append(optQuery.Additionals, + dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("."), + Class: dnsmessage.ClassINET, + TTL: 0, + }, + Body: &dnsmessage.OPTResource{ + Options: []dnsmessage.Option{}, + }, + }, + ) + paddedOnWire, err := AddEdnsPadding(mustPack(&optQuery)) + if err != nil { + t.Errorf("Failed to pad query with OPT but no padding: %v", err) + } + if len(paddedOnWire)%PaddingBlockSize != 0 { + t.Errorf("AddEdnsPadding failed to correctly pad query with OPT but no padding") + } +} + +// Try to pad a query that already contains an OPT record with padding. The +// query should be unmodified by AddEdnsPadding. +func TestAddEdnsPaddingCompressedPaddedQuery(t *testing.T) { + paddedQuery := simpleQuery + paddedQuery.Additionals = make([]dnsmessage.Resource, len(simpleQuery.Additionals)) + copy(paddedQuery.Additionals, simpleQuery.Additionals) + + paddedQuery.Additionals = append(paddedQuery.Additionals, + dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("."), + Class: dnsmessage.ClassINET, + TTL: 0, + }, + Body: &dnsmessage.OPTResource{ + Options: []dnsmessage.Option{ + { + Code: OptResourcePaddingCode, + Data: make([]byte, 5), + }, + }, + }, + }, + ) + originalOnWire := mustPack(&paddedQuery) + + paddedOnWire, err := AddEdnsPadding(mustPack(&paddedQuery)) + if err != nil { + t.Errorf("Failed to pad padded query: %v", err) + } + + if !bytes.Equal(originalOnWire, paddedOnWire) { + t.Errorf("AddEdnsPadding tampered with a query that was already padded") + } +} + +func TestServfail(t *testing.T) { + sf, err := Servfail(simpleQueryBytes) + if err != nil { + t.Fatal(err) + } + servfail := mustUnpack(sf) + expectedHeader := dnsmessage.Header{ + ID: 0xbeef, + Response: true, + OpCode: 0, + Authoritative: false, + Truncated: false, + RecursionDesired: true, + RecursionAvailable: true, + RCode: 2, + } + if servfail.Header != expectedHeader { + t.Errorf("Wrong header: %v != %v", servfail.Header, expectedHeader) + } + if servfail.Questions[0] != simpleQuery.Questions[0] { + t.Errorf("Wrong question: %v", servfail.Questions[0]) + } +} diff --git a/Android/tun2socks/intra/doh/ipmap/ipmap.go b/Android/tun2socks/intra/doh/ipmap/ipmap.go new file mode 100644 index 00000000..9c94a5db --- /dev/null +++ b/Android/tun2socks/intra/doh/ipmap/ipmap.go @@ -0,0 +1,164 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipmap + +import ( + "context" + "math/rand" + "net" + "sync" + + "github.com/eycorsican/go-tun2socks/common/log" +) + +// IPMap maps hostnames to IPSets. +type IPMap interface { + // Get creates an IPSet for this hostname populated with the IPs + // discovered by resolving it. Subsequent calls to Get return the + // same IPSet. + Get(hostname string) *IPSet +} + +// NewIPMap returns a fresh IPMap. +// `r` will be used to resolve any hostnames passed to `Get` or `Add`. +func NewIPMap(r *net.Resolver) IPMap { + return &ipMap{ + m: make(map[string]*IPSet), + r: r, + } +} + +type ipMap struct { + sync.RWMutex + m map[string]*IPSet + r *net.Resolver +} + +func (m *ipMap) Get(hostname string) *IPSet { + m.RLock() + s := m.m[hostname] + m.RUnlock() + if s != nil { + return s + } + + s = &IPSet{r: m.r} + s.Add(hostname) + + m.Lock() + s2 := m.m[hostname] + if s2 == nil { + m.m[hostname] = s + } else { + // Another pending call to Get populated m[hostname] + // while we were building s. Use that one to ensure + // consistency. + s = s2 + } + m.Unlock() + + return s +} + +// IPSet represents an unordered collection of IP addresses for a single host. +// One IP can be marked as confirmed to be working correctly. +type IPSet struct { + sync.RWMutex + ips []net.IP // All known IPs for the server. + confirmed net.IP // IP address confirmed to be working + r *net.Resolver // Resolver to use for hostname resolution +} + +// Reports whether ip is in the set. Must be called under RLock. +func (s *IPSet) has(ip net.IP) bool { + for _, oldIP := range s.ips { + if oldIP.Equal(ip) { + return true + } + } + return false +} + +// Adds an IP to the set if it is not present. Must be called under Lock. +func (s *IPSet) add(ip net.IP) { + if !s.has(ip) { + s.ips = append(s.ips, ip) + } +} + +// Add one or more IP addresses to the set. +// The hostname can be a domain name or an IP address. +func (s *IPSet) Add(hostname string) { + // Don't hold the ipMap lock during blocking I/O. + resolved, err := s.r.LookupIPAddr(context.TODO(), hostname) + if err != nil { + log.Warnf("Failed to resolve %s: %v", hostname, err) + } + s.Lock() + for _, addr := range resolved { + s.add(addr.IP) + } + s.Unlock() +} + +// Empty reports whether the set is empty. +func (s *IPSet) Empty() bool { + s.RLock() + defer s.RUnlock() + return len(s.ips) == 0 +} + +// GetAll returns a copy of the IP set as a slice in random order. +// The slice is owned by the caller, but the elements are owned by the set. +func (s *IPSet) GetAll() []net.IP { + s.RLock() + c := append([]net.IP{}, s.ips...) + s.RUnlock() + rand.Shuffle(len(c), func(i, j int) { + c[i], c[j] = c[j], c[i] + }) + return c +} + +// Confirmed returns the confirmed IP address, or nil if there is no such address. +func (s *IPSet) Confirmed() net.IP { + s.RLock() + defer s.RUnlock() + return s.confirmed +} + +// Confirm marks ip as the confirmed address. +func (s *IPSet) Confirm(ip net.IP) { + // Optimization: Skip setting if it hasn't changed. + if ip.Equal(s.Confirmed()) { + // This is the common case. + return + } + s.Lock() + // Add is O(N) + s.add(ip) + s.confirmed = ip + s.Unlock() +} + +// Disconfirm sets the confirmed address to nil if the current confirmed address +// is the provided ip. +func (s *IPSet) Disconfirm(ip net.IP) { + s.Lock() + if ip.Equal(s.confirmed) { + s.confirmed = nil + } + s.Unlock() +} diff --git a/Android/tun2socks/intra/doh/ipmap/ipmap_test.go b/Android/tun2socks/intra/doh/ipmap/ipmap_test.go new file mode 100644 index 00000000..be5d9391 --- /dev/null +++ b/Android/tun2socks/intra/doh/ipmap/ipmap_test.go @@ -0,0 +1,177 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipmap + +import ( + "context" + "errors" + "net" + "sync/atomic" + "testing" +) + +// We use '.' at the end to make sure resolution treats it an inexistent root domain. +// It must not resolve to any address. +const invalidDomain = "invaliddomain." + +func TestGetTwice(t *testing.T) { + m := NewIPMap(nil) + a := m.Get("example") + b := m.Get("example") + if a != b { + t.Error("Matched Get returned different objects") + } +} + +func TestGetInvalid(t *testing.T) { + m := NewIPMap(nil) + s := m.Get(invalidDomain) + if !s.Empty() { + t.Errorf("Invalid name should result in an empty set, got %v", s.ips) + } + if len(s.GetAll()) != 0 { + t.Errorf("Empty set should be empty, got %v", s.GetAll()) + } +} + +func TestGetDomain(t *testing.T) { + m := NewIPMap(nil) + s := m.Get("www.google.com") + if s.Empty() { + t.Error("Google lookup failed") + } + ips := s.GetAll() + if len(ips) == 0 { + t.Fatal("IP set is empty") + } + if ips[0] == nil { + t.Error("nil IP in set") + } +} + +func TestGetIP(t *testing.T) { + m := NewIPMap(nil) + s := m.Get("192.0.2.1") + if s.Empty() { + t.Error("IP parsing failed") + } + ips := s.GetAll() + if len(ips) != 1 { + t.Errorf("Wrong IP set size %d", len(ips)) + } + if ips[0].String() != "192.0.2.1" { + t.Error("Wrong IP") + } +} + +func TestAddDomain(t *testing.T) { + m := NewIPMap(nil) + s := m.Get(invalidDomain) + s.Add("www.google.com") + if s.Empty() { + t.Error("Google lookup failed") + } + ips := s.GetAll() + if len(ips) == 0 { + t.Fatal("IP set is empty") + } + if ips[0] == nil { + t.Error("nil IP in set") + } +} +func TestAddIP(t *testing.T) { + m := NewIPMap(nil) + s := m.Get(invalidDomain) + s.Add("192.0.2.1") + ips := s.GetAll() + if len(ips) != 1 { + t.Errorf("Wrong IP set size %d", len(ips)) + } + if ips[0].String() != "192.0.2.1" { + t.Error("Wrong IP") + } +} + +func TestConfirmed(t *testing.T) { + m := NewIPMap(nil) + s := m.Get("www.google.com") + if s.Confirmed() != nil { + t.Error("Confirmed should start out nil") + } + + ips := s.GetAll() + s.Confirm(ips[0]) + if !ips[0].Equal(s.Confirmed()) { + t.Error("Confirmation failed") + } + + s.Disconfirm(ips[0]) + if s.Confirmed() != nil { + t.Error("Confirmed should now be nil") + } +} + +func TestConfirmNew(t *testing.T) { + m := NewIPMap(nil) + s := m.Get(invalidDomain) + s.Add("192.0.2.1") + // Confirm a new address. + s.Confirm(net.ParseIP("192.0.2.2")) + if s.Confirmed() == nil || s.Confirmed().String() != "192.0.2.2" { + t.Error("Confirmation failed") + } + ips := s.GetAll() + if len(ips) != 2 { + t.Error("New address not added to the set") + } +} + +func TestDisconfirmMismatch(t *testing.T) { + m := NewIPMap(nil) + s := m.Get("www.google.com") + ips := s.GetAll() + s.Confirm(ips[0]) + + // Make a copy + otherIP := net.ParseIP(ips[0].String()) + // Alter it + otherIP[0]++ + // Disconfirm. This should have no effect because otherIP + // is not the confirmed IP. + s.Disconfirm(otherIP) + + if !ips[0].Equal(s.Confirmed()) { + t.Error("Mismatched disconfirmation") + } +} + +func TestResolver(t *testing.T) { + var dialCount int32 + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(context context.Context, network, address string) (net.Conn, error) { + atomic.AddInt32(&dialCount, 1) + return nil, errors.New("Fake dialer") + }, + } + m := NewIPMap(resolver) + s := m.Get("www.google.com") + if !s.Empty() { + t.Error("Google lookup should have failed due to fake dialer") + } + if atomic.LoadInt32(&dialCount) == 0 { + t.Error("Fake dialer didn't run") + } +} diff --git a/Android/tun2socks/intra/doh/padding.go b/Android/tun2socks/intra/doh/padding.go new file mode 100644 index 00000000..d515ca5d --- /dev/null +++ b/Android/tun2socks/intra/doh/padding.go @@ -0,0 +1,118 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doh + +import ( + "golang.org/x/net/dns/dnsmessage" +) + +const ( + OptResourcePaddingCode = 12 + PaddingBlockSize = 128 // RFC8467 recommendation +) + +const kOptRrHeaderLen int = 1 + // DOMAIN NAME + 2 + // TYPE + 2 + // CLASS + 4 + // TTL + 2 // RDLEN + +const kOptPaddingHeaderLen int = 2 + // OPTION-CODE + 2 // OPTION-LENGTH + +// Compute the number of padding bytes needed, excluding headers. +// Assumes that |msgLen| is the length of a raw DNS message that contains an +// OPT RR with no RFC7830 padding option, and that the message is fully +// label-compressed. +func computePaddingSize(msgLen int, blockSize int) int { + // We'll always be adding a new padding header inside the OPT + // RR's data. + extraPadding := kOptPaddingHeaderLen + + padSize := blockSize - (msgLen+extraPadding)%blockSize + return padSize % blockSize +} + +// Create an appropriately-sized padding option. Precondition: |msgLen| is the +// length of a message that already contains an OPT RR. +func getPadding(msgLen int) dnsmessage.Option { + optPadding := dnsmessage.Option{ + Code: OptResourcePaddingCode, + Data: make([]byte, computePaddingSize(msgLen, PaddingBlockSize)), + } + return optPadding +} + +// Add EDNS padding, as defined in RFC7830, to a raw DNS message. +func AddEdnsPadding(rawMsg []byte) ([]byte, error) { + var msg dnsmessage.Message + if err := msg.Unpack(rawMsg); err != nil { + return nil, err + } + + // Search for OPT resource and save |optRes| pointer if possible. + var optRes *dnsmessage.OPTResource = nil + for _, additional := range msg.Additionals { + switch body := additional.Body.(type) { + case *dnsmessage.OPTResource: + optRes = body + break + } + } + if optRes != nil { + // Search for a padding Option. If the message already contains + // padding, we will respect the stub resolver's padding. + for _, option := range optRes.Options { + if option.Code == OptResourcePaddingCode { + return rawMsg, nil + } + } + // At this point, |optRes| points to an OPTResource that does + // not contain a padding option. + } else { + // Create an empty OPTResource (contains no padding option) and + // push it into |msg.Additionals|. + optRes = &dnsmessage.OPTResource{ + Options: []dnsmessage.Option{}, + } + + optHeader := dnsmessage.ResourceHeader{} + // SetEDNS0(udpPayloadLen int, extRCode RCode, dnssecOK bool) error + err := optHeader.SetEDNS0(65535, dnsmessage.RCodeSuccess, false) + if err != nil { + return nil, err + } + + msg.Additionals = append(msg.Additionals, dnsmessage.Resource{ + Header: optHeader, + Body: optRes, + }) + } + // At this point, |msg| contains an OPT resource, and that OPT resource + // does not contain a padding option. + + // Compress the message to determine its size before padding. + compressedMsg, err := msg.Pack() + if err != nil { + return nil, err + } + // Add the padding option to |msg| that will round its size on the wire + // up to the nearest block. + paddingOption := getPadding(len(compressedMsg)) + optRes.Options = append(optRes.Options, paddingOption) + + // Re-pack the message, with compression unconditionally enabled. + return msg.Pack() +} diff --git a/Android/tun2socks/intra/ip.go b/Android/tun2socks/intra/ip.go new file mode 100644 index 00000000..11bc265e --- /dev/null +++ b/Android/tun2socks/intra/ip.go @@ -0,0 +1,23 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import "net/netip" + +// isEquivalentAddrPort checks if addr1 and addr2 are equivalent. More specifically, it will treat +// "ffff::127.0.0.1" (IPv4-in-6) and "127.0.0.1" (IPv4) as equivalent, even though they are "!=" in Go. +func isEquivalentAddrPort(addr1, addr2 netip.AddrPort) bool { + return addr1.Addr().Unmap() == addr2.Addr().Unmap() && addr1.Port() == addr2.Port() +} diff --git a/Android/tun2socks/intra/ip_test.go b/Android/tun2socks/intra/ip_test.go new file mode 100644 index 00000000..ed286609 --- /dev/null +++ b/Android/tun2socks/intra/ip_test.go @@ -0,0 +1,76 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "net/netip" + "testing" +) + +func TestIsEquivalentAddrPort(t *testing.T) { + cases := []struct { + in1, in2 netip.AddrPort + want bool + msg string + }{ + { + in1: netip.MustParseAddrPort("12.34.56.78:80"), + in2: netip.AddrPortFrom(netip.AddrFrom4([4]byte{12, 34, 56, 78}), 80), + want: true, + }, + { + in1: netip.MustParseAddrPort("[fe80::1234:5678]:443"), + in2: netip.AddrPortFrom(netip.AddrFrom16([16]byte{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x12, 0x34, 0x56, 0x78}), 443), + want: true, + }, + { + in1: netip.MustParseAddrPort("0.0.0.0:80"), + in2: netip.MustParseAddrPort("127.0.0.1:80"), + want: false, + }, + { + in1: netip.AddrPortFrom(netip.IPv6Unspecified(), 80), + in2: netip.AddrPortFrom(netip.IPv6Loopback(), 80), + want: false, + }, + { + in1: netip.MustParseAddrPort("127.0.0.1:38880"), + in2: netip.MustParseAddrPort("127.0.0.1:38888"), + want: false, + }, + { + in1: netip.MustParseAddrPort("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:33443"), + in2: netip.MustParseAddrPort("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:33444"), + want: false, + }, + { + in1: netip.MustParseAddrPort("127.0.0.1:8080"), + in2: netip.MustParseAddrPort("[::ffff:127.0.0.1]:8080"), + want: true, + }, + { + in1: netip.AddrPortFrom(netip.IPv6Loopback(), 80), + in2: netip.MustParseAddrPort("127.0.0.1:80"), + want: false, + }, + } + + for _, tc := range cases { + actual := isEquivalentAddrPort(tc.in1, tc.in2) + if actual != tc.want { + t.Fatalf(`"%v" == "%v"? want %v, actual %v`, tc.in1, tc.in2, tc.want, actual) + } + } +} diff --git a/Android/tun2socks/intra/packet_proxy.go b/Android/tun2socks/intra/packet_proxy.go new file mode 100644 index 00000000..c5f39b63 --- /dev/null +++ b/Android/tun2socks/intra/packet_proxy.go @@ -0,0 +1,164 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "errors" + "fmt" + "net" + "net/netip" + "sync/atomic" + "time" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/protect" + "github.com/Jigsaw-Code/outline-sdk/network" + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type intraPacketProxy struct { + fakeDNSAddr netip.AddrPort + dns atomic.Pointer[doh.Transport] + proxy network.PacketProxy + listener UDPListener +} + +var _ network.PacketProxy = (*intraPacketProxy)(nil) + +func newIntraPacketProxy( + fakeDNS netip.AddrPort, dns doh.Transport, protector protect.Protector, listener UDPListener, +) (*intraPacketProxy, error) { + if dns == nil { + return nil, errors.New("dns is required") + } + + pl := &transport.UDPPacketListener{ + ListenConfig: *protect.MakeListenConfig(protector), + } + + // RFC 4787 REQ-5 requires a timeout no shorter than 5 minutes. + pp, err := network.NewPacketProxyFromPacketListener(pl, network.WithPacketListenerWriteIdleTimeout(5*time.Minute)) + if err != nil { + return nil, fmt.Errorf("failed to create packet proxy from listener: %w", err) + } + + dohpp := &intraPacketProxy{ + fakeDNSAddr: fakeDNS, + proxy: pp, + listener: listener, + } + dohpp.dns.Store(&dns) + + return dohpp, nil +} + +// NewSession implements PacketProxy.NewSession. +func (p *intraPacketProxy) NewSession(resp network.PacketResponseReceiver) (network.PacketRequestSender, error) { + dohResp := &dohPacketRespReceiver{ + PacketResponseReceiver: resp, + stats: makeTracker(), + listener: p.listener, + } + req, err := p.proxy.NewSession(dohResp) + if err != nil { + return nil, fmt.Errorf("failed to create new session: %w", err) + } + + return &dohPacketReqSender{ + PacketRequestSender: req, + proxy: p, + response: dohResp, + stats: dohResp.stats, + }, nil +} + +func (p *intraPacketProxy) SetDNS(dns doh.Transport) error { + if dns == nil { + return errors.New("dns is required") + } + p.dns.Store(&dns) + return nil +} + +// DoH PacketRequestSender wrapper +type dohPacketReqSender struct { + network.PacketRequestSender + + response *dohPacketRespReceiver + proxy *intraPacketProxy + stats *tracker +} + +// DoH PacketResponseReceiver wrapper +type dohPacketRespReceiver struct { + network.PacketResponseReceiver + + stats *tracker + listener UDPListener +} + +var _ network.PacketRequestSender = (*dohPacketReqSender)(nil) +var _ network.PacketResponseReceiver = (*dohPacketRespReceiver)(nil) + +// WriteTo implements PacketRequestSender.WriteTo. It will query the DoH server if the packet a DNS packet. +func (req *dohPacketReqSender) WriteTo(p []byte, destination netip.AddrPort) (int, error) { + if isEquivalentAddrPort(destination, req.proxy.fakeDNSAddr) { + defer func() { + // conn was only used for this DNS query, so it's unlikely to be used again + if req.stats.download.Load() == 0 && req.stats.upload.Load() == 0 { + req.Close() + } + }() + + resp, err := (*req.proxy.dns.Load()).Query(p) + if err != nil { + return 0, fmt.Errorf("DoH request error: %w", err) + } + if len(resp) == 0 { + return 0, errors.New("empty DoH response") + } + + return req.response.writeFrom(resp, net.UDPAddrFromAddrPort(req.proxy.fakeDNSAddr), false) + } + + req.stats.upload.Add(int64(len(p))) + return req.PacketRequestSender.WriteTo(p, destination) +} + +// Close terminates the UDP session, and reports session stats to the listener. +func (resp *dohPacketRespReceiver) Close() error { + if resp.listener != nil { + resp.listener.OnUDPSocketClosed(&UDPSocketSummary{ + Duration: int32(time.Since(resp.stats.start)), + UploadBytes: resp.stats.upload.Load(), + DownloadBytes: resp.stats.download.Load(), + }) + } + return resp.PacketResponseReceiver.Close() +} + +// WriteFrom implements PacketResponseReceiver.WriteFrom. +func (resp *dohPacketRespReceiver) WriteFrom(p []byte, source net.Addr) (int, error) { + return resp.writeFrom(p, source, true) +} + +// writeFrom writes to the underlying PacketResponseReceiver. +// It will also add len(p) to downloadBytes if doStat is true. +func (resp *dohPacketRespReceiver) writeFrom(p []byte, source net.Addr, doStat bool) (int, error) { + if doStat { + resp.stats.download.Add(int64(len(p))) + } + return resp.PacketResponseReceiver.WriteFrom(p, source) +} diff --git a/Android/tun2socks/intra/protect/protect.go b/Android/tun2socks/intra/protect/protect.go new file mode 100644 index 00000000..f3d29acc --- /dev/null +++ b/Android/tun2socks/intra/protect/protect.go @@ -0,0 +1,127 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protect + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "syscall" + + "github.com/eycorsican/go-tun2socks/common/log" +) + +// Protector provides the ability to bypass a VPN on Android, pre-Lollipop. +type Protector interface { + // Protect a socket, i.e. exclude it from the VPN. + // This is needed in order to avoid routing loops for the VPN's own sockets. + // This is a wrapper for Android's VpnService.protect(). + Protect(socket int32) bool + + // Returns a comma-separated list of the system's configured DNS resolvers, + // in roughly descending priority order. + // This is needed because (1) Android Java cannot protect DNS lookups but Go can, and + // (2) Android Java can determine the list of system DNS resolvers but Go cannot. + // A comma-separated list is used because Gomobile cannot bind []string. + GetResolvers() string +} + +func makeControl(p Protector) func(string, string, syscall.RawConn) error { + return func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if !p.Protect(int32(fd)) { + // TODO: Record and report these errors. + log.Errorf("Failed to protect a %s socket", network) + } + }) + } +} + +// Returns the first IP address that is of the desired family. +func scan(ips []string, wantV4 bool) string { + for _, ip := range ips { + parsed := net.ParseIP(ip) + if parsed == nil { + // `ip` failed to parse. Skip it. + continue + } + isV4 := parsed.To4() != nil + if isV4 == wantV4 { + return ip + } + } + return "" +} + +// Given a slice of IP addresses, and a transport address, return a transport +// address with the IP replaced by the first IP of the same family in `ips`, or +// by the first address of a different family if there are none of the same. +func replaceIP(addr string, ips []string) (string, error) { + if len(ips) == 0 { + return "", errors.New("No resolvers available") + } + orighost, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + origip := net.ParseIP(orighost) + if origip == nil { + return "", fmt.Errorf("Can't parse resolver IP: %s", orighost) + } + isV4 := origip.To4() != nil + newIP := scan(ips, isV4) + if newIP == "" { + // There are no IPs of the desired address family. Use a different family. + newIP = ips[0] + } + return net.JoinHostPort(newIP, port), nil +} + +// MakeDialer creates a new Dialer. Recipients can safely mutate +// any public field except Control and Resolver, which are both populated. +func MakeDialer(p Protector) *net.Dialer { + if p == nil { + return &net.Dialer{} + } + d := &net.Dialer{ + Control: makeControl(p), + } + resolverDialer := func(ctx context.Context, network, address string) (net.Conn, error) { + resolvers := strings.Split(p.GetResolvers(), ",") + newAddress, err := replaceIP(address, resolvers) + if err != nil { + return nil, err + } + return d.DialContext(ctx, network, newAddress) + } + d.Resolver = &net.Resolver{ + PreferGo: true, + Dial: resolverDialer, + } + return d +} + +// MakeListenConfig returns a new ListenConfig that creates protected +// listener sockets. +func MakeListenConfig(p Protector) *net.ListenConfig { + if p == nil { + return &net.ListenConfig{} + } + return &net.ListenConfig{ + Control: makeControl(p), + } +} diff --git a/Android/tun2socks/intra/protect/protect_test.go b/Android/tun2socks/intra/protect/protect_test.go new file mode 100644 index 00000000..dc6d7459 --- /dev/null +++ b/Android/tun2socks/intra/protect/protect_test.go @@ -0,0 +1,142 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protect + +import ( + "context" + "net" + "sync" + "syscall" + "testing" +) + +// The fake protector just records the file descriptors it was given. +type fakeProtector struct { + mu sync.Mutex + fds []int32 +} + +func (p *fakeProtector) Protect(fd int32) bool { + p.mu.Lock() + p.fds = append(p.fds, fd) + p.mu.Unlock() + return true +} + +func (p *fakeProtector) GetResolvers() string { + return "8.8.8.8,2001:4860:4860::8888" +} + +// This interface serves as a supertype of net.TCPConn and net.UDPConn, so +// that they can share the verifyMatch() function. +type hasSyscallConn interface { + SyscallConn() (syscall.RawConn, error) +} + +func verifyMatch(t *testing.T, conn hasSyscallConn, p *fakeProtector) { + rawconn, err := conn.SyscallConn() + if err != nil { + t.Fatal(err) + } + rawconn.Control(func(fd uintptr) { + if len(p.fds) == 0 { + t.Fatalf("No file descriptors") + } + if int32(fd) != p.fds[0] { + t.Fatalf("File descriptor mismatch: %d != %d", fd, p.fds[0]) + } + }) +} + +func TestDialTCP(t *testing.T) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + go l.Accept() + + p := &fakeProtector{} + d := MakeDialer(p) + if d.Control == nil { + t.Errorf("Control function is nil") + } + + conn, err := d.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + verifyMatch(t, conn.(*net.TCPConn), p) + l.Close() + conn.Close() +} + +func TestListenUDP(t *testing.T) { + udpaddr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + p := &fakeProtector{} + c := MakeListenConfig(p) + + conn, err := c.ListenPacket(context.Background(), udpaddr.Network(), udpaddr.String()) + if err != nil { + t.Fatal(err) + } + verifyMatch(t, conn.(*net.UDPConn), p) + conn.Close() +} + +func TestLookupIPAddr(t *testing.T) { + p := &fakeProtector{} + d := MakeDialer(p) + d.Resolver.LookupIPAddr(context.Background(), "foo.test.") + // Verify that Protect was called. + if len(p.fds) == 0 { + t.Fatal("Protect was not called") + } +} + +func TestNilDialer(t *testing.T) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + go l.Accept() + + d := MakeDialer(nil) + conn, err := d.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + conn.Close() + l.Close() +} + +func TestNilListener(t *testing.T) { + udpaddr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + c := MakeListenConfig(nil) + conn, err := c.ListenPacket(context.Background(), udpaddr.Network(), udpaddr.String()) + if err != nil { + t.Fatal(err) + } + + conn.Close() +} diff --git a/Android/tun2socks/intra/sni_reporter.go b/Android/tun2socks/intra/sni_reporter.go new file mode 100644 index 00000000..3ac5a109 --- /dev/null +++ b/Android/tun2socks/intra/sni_reporter.go @@ -0,0 +1,116 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "io" + "sync" + "time" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/choir" + "github.com/eycorsican/go-tun2socks/common/log" +) + +// Number of bins to assign reports to. Should be large enough for +// k-anonymity goals. See the Choir documentation for more info. +const bins = 32 + +// Number of values in each report. The two values are +// * success/failure +// * timeout/close +const values = 2 + +// Burst duration. Only one report will be sent in each interval +// to avoid correlated reports. +const burst = 10 * time.Second + +// tcpSNIReporter is a thread-safe wrapper around choir.Reporter +type tcpSNIReporter struct { + mu sync.RWMutex // Protects dns, suffix, and r. + dns doh.Transport + suffix string + r choir.Reporter +} + +// SetDNS changes the DNS transport used for uploading reports. +func (r *tcpSNIReporter) SetDNS(dns doh.Transport) { + r.mu.Lock() + r.dns = dns + r.mu.Unlock() +} + +// Send implements choir.ReportSender. +func (r *tcpSNIReporter) Send(report choir.Report) error { + r.mu.RLock() + suffix := r.suffix + dns := r.dns + r.mu.RUnlock() + q, err := choir.FormatQuery(report, suffix) + if err != nil { + log.Warnf("Failed to construct query for Choir: %v", err) + return nil + } + if _, err = dns.Query(q); err != nil { + log.Infof("Failed to deliver query for Choir: %v", err) + } + return nil +} + +// Configure initializes or reinitializes the reporter. +// `file` is the Choir salt file (persistent and initially empty). +// `suffix` is the domain to which reports will be sent. +// `country` is the two-letter ISO country code of the user's location. +func (r *tcpSNIReporter) Configure(file io.ReadWriter, suffix, country string) (err error) { + r.mu.Lock() + r.suffix = suffix + r.r, err = choir.NewReporter(file, bins, values, country, burst, r) + r.mu.Unlock() + return +} + +// Report converts `summary` into a Choir report and queues it for delivery. +func (r *tcpSNIReporter) Report(summary TCPSocketSummary) { + if summary.Retry.Split == 0 { + return // Nothing to report + } + + r.mu.RLock() + reporter := r.r + r.mu.RUnlock() + + if reporter == nil { + return // Reports are disabled + } + result := "failed" + if summary.DownloadBytes > 0 { + result = "success" + } + response := "closed" + if summary.Retry.Timeout { + response = "timeout" + } + resultValue, err := choir.NewValue(result) + if err != nil { + log.Fatalf("Bad result %s: %v", result, err) + } + responseValue, err := choir.NewValue(response) + if err != nil { + log.Fatalf("Bad response %s: %v", response, err) + } + if err := reporter.Report(summary.Retry.SNI, resultValue, responseValue); err != nil { + log.Warnf("Choir report failed: %v", err) + } +} diff --git a/Android/tun2socks/intra/sni_reporter_test.go b/Android/tun2socks/intra/sni_reporter_test.go new file mode 100644 index 00000000..9da76b44 --- /dev/null +++ b/Android/tun2socks/intra/sni_reporter_test.go @@ -0,0 +1,209 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "bytes" + "errors" + "strings" + "testing" + + "golang.org/x/net/dns/dnsmessage" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/split" +) + +type qfunc func(q []byte) ([]byte, error) + +type fakeTransport struct { + doh.Transport + query qfunc +} + +func (t *fakeTransport) Query(q []byte) ([]byte, error) { + return t.query(q) +} + +func newFakeTransport(query qfunc) *fakeTransport { + return &fakeTransport{query: query} +} + +func sendReport(t *testing.T, r *tcpSNIReporter, summary TCPSocketSummary, response []byte, responseErr error) string { + // This function blocks for the burst duration (10 seconds), so it's important that + // all tests that use it run in parallel to avoid extreme test delays. + t.Parallel() + + c := make(chan string) + dns := newFakeTransport(func(q []byte) ([]byte, error) { + var msg dnsmessage.Message + err := msg.Unpack(q) + if err != nil { + t.Fatal(err) + } + name := msg.Questions[0].Name.String() + c <- name + return response, responseErr + }) + r.SetDNS(dns) + r.Report(summary) + return <-c +} + +const suffix = "mydomain.example" +const country = "zz" + +func runSuccessTest(t *testing.T, summary TCPSocketSummary) string { + r := tcpSNIReporter{} + var stubFile bytes.Buffer + r.Configure(&stubFile, suffix, country) + return sendReport(t, &r, summary, make([]byte, 100), nil) +} + +func TestSuccessClosed(t *testing.T) { + summary := TCPSocketSummary{ + DownloadBytes: 10000, // >0 indicates success + UploadBytes: 5000, + Retry: &split.RetryStats{ + Timeout: false, // Socket was explicitly closed + Split: 48, // >0 indicates a split was attempted + SNI: "user.domain.test", // SNI of the socket + }, + } + name := runSuccessTest(t, summary) + labels := strings.Split(name, ".") + if labels[0] != "success" { + t.Errorf("Bad name %s, %s != success", name, labels[0]) + } + if labels[1] != "closed" { + t.Errorf("Bad name %s, %s != closed", name, labels[1]) + } + // labels[2] is the bin, which is random. + if labels[3] != "zz" { + t.Errorf("Bad name %s, %s != zz", name, labels[1]) + } + // labels[4] is the date, which is not controlled by the code under test. + remainder := strings.Join(labels[5:], ".") + expected := summary.Retry.SNI + "." + suffix + "." + if remainder != expected { + t.Errorf("Bad name %s, %s != %s", name, remainder, expected) + } +} + +func TestTimeout(t *testing.T) { + summary := TCPSocketSummary{ + DownloadBytes: 10000, // >0 indicates success + UploadBytes: 5000, + Retry: &split.RetryStats{ + Timeout: true, // Socket timed out + Split: 54, // >0 indicates a split was attempted + SNI: "user.domain.test", // SNI of the socket + }, + } + name := runSuccessTest(t, summary) + labels := strings.Split(name, ".") + if labels[1] != "timeout" { + t.Errorf("Bad name %s, %s != timeout", name, labels[1]) + } +} + +func TestFail(t *testing.T) { + summary := TCPSocketSummary{ + DownloadBytes: 0, // 0 indicates failure + UploadBytes: 500, + Retry: &split.RetryStats{ + Timeout: true, // Socket timed out + Split: 36, // >0 indicates a split was attempted + SNI: "user.domain.test", // SNI of the socket + }, + } + name := runSuccessTest(t, summary) + labels := strings.Split(name, ".") + if labels[0] != "failed" { + t.Errorf("Bad name %s, %s != failed", name, labels[0]) + } +} + +func TestError(t *testing.T) { + r := tcpSNIReporter{} + var stubFile bytes.Buffer + r.Configure(&stubFile, suffix, country) + summary := TCPSocketSummary{ + DownloadBytes: 5000, + UploadBytes: 500, + Retry: &split.RetryStats{ + Timeout: true, + Split: 36, + SNI: "user.domain.test", + }, + } + // Verify that I/O errors don't cause a panic. + sendReport(t, &r, summary, nil, errors.New("DNS send failed")) +} + +func TestNoSplit(t *testing.T) { + r := tcpSNIReporter{} + var stubFile bytes.Buffer + r.Configure(&stubFile, suffix, country) + summary := TCPSocketSummary{ + DownloadBytes: 5000, + UploadBytes: 500, + Retry: &split.RetryStats{ + Timeout: true, + Split: 0, + SNI: "user.domain.test", + }, + } + dns := newFakeTransport(func(q []byte) ([]byte, error) { + t.Error("DNS query function should not be called because no split was performed") + return nil, errors.New("Unreachable") + }) + r.SetDNS(dns) + r.Report(summary) +} + +func TestUnconfigured(t *testing.T) { + r := tcpSNIReporter{} + summary := TCPSocketSummary{ + DownloadBytes: 5000, + UploadBytes: 500, + Retry: &split.RetryStats{ + Timeout: true, + Split: 45, + SNI: "user.domain.test", + }, + } + dns := newFakeTransport(func(q []byte) ([]byte, error) { + t.Error("DNS query function should not be called because the reporter is not configured") + return nil, errors.New("Unreachable") + }) + r.SetDNS(dns) + r.Report(summary) +} + +func TestNoDNS(t *testing.T) { + r := tcpSNIReporter{} + summary := TCPSocketSummary{ + DownloadBytes: 5000, + UploadBytes: 500, + Retry: &split.RetryStats{ + Timeout: true, + Split: 45, + SNI: "user.domain.test", + }, + } + // Verify that this doesn't panic. + r.Report(summary) +} diff --git a/Android/tun2socks/intra/split/direct_split.go b/Android/tun2socks/intra/split/direct_split.go new file mode 100644 index 00000000..16e3e730 --- /dev/null +++ b/Android/tun2socks/intra/split/direct_split.go @@ -0,0 +1,80 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package split + +import ( + "io" + "net" +) + +// DuplexConn represents a bidirectional stream socket. +type DuplexConn interface { + net.Conn + io.ReaderFrom + CloseWrite() error + CloseRead() error +} + +type splitter struct { + *net.TCPConn + used bool // Initially false. Becomes true after the first write. +} + +// DialWithSplit returns a TCP connection that always splits the initial upstream segment. +// Like net.Conn, it is intended for two-threaded use, with one thread calling +// Read and CloseRead, and another calling Write, ReadFrom, and CloseWrite. +func DialWithSplit(d *net.Dialer, addr *net.TCPAddr) (DuplexConn, error) { + conn, err := d.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + + return &splitter{TCPConn: conn.(*net.TCPConn)}, nil +} + +// Write-related functions +func (s *splitter) Write(b []byte) (int, error) { + conn := s.TCPConn + if s.used { + // After the first write, there is no special write behavior. + return conn.Write(b) + } + + // Setting `used` to true ensures that this code only runs once per socket. + s.used = true + b1, b2 := splitHello(b) + n1, err := conn.Write(b1) + if err != nil { + return n1, err + } + n2, err := conn.Write(b2) + return n1 + n2, err +} + +func (s *splitter) ReadFrom(reader io.Reader) (bytes int64, err error) { + if !s.used { + // This is the first write on this socket. + // Use copyOnce(), which calls Write(), to get Write's splitting behavior for + // the first segment. + if bytes, err = copyOnce(s, reader); err != nil { + return + } + } + + var b int64 + b, err = s.TCPConn.ReadFrom(reader) + bytes += b + return +} diff --git a/Android/tun2socks/intra/split/example/main.go b/Android/tun2socks/intra/split/example/main.go new file mode 100644 index 00000000..850662af --- /dev/null +++ b/Android/tun2socks/intra/split/example/main.go @@ -0,0 +1,80 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "os" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/split" +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [-sni=SNI] destination\n", os.Args[0]) + fmt.Fprintln(flag.CommandLine.Output(), "This tool attempts a TLS connection to the "+ + "destination (port 443), with and without splitting. If the SNI is specified, it "+ + "overrides the destination, which can be an IP address.") + flag.PrintDefaults() + } + + sni := flag.String("sni", "", "Server name override") + flag.Parse() + destination := flag.Arg(0) + if destination == "" { + flag.Usage() + return + } + + addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(destination, "443")) + if err != nil { + log.Fatalf("Couldn't resolve destination: %v", err) + } + + if *sni == "" { + *sni = destination + } + tlsConfig := &tls.Config{ServerName: *sni} + + log.Println("Trying direct connection") + conn, err := net.DialTCP(addr.Network(), nil, addr) + if err != nil { + log.Fatalf("Could not establish a TCP connection: %v", err) + } + tlsConn := tls.Client(conn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + log.Printf("Direct TLS handshake failed: %v", err) + } else { + log.Printf("Direct TLS succeeded") + } + + log.Println("Trying split connection") + splitConn, err := split.DialWithSplit(&net.Dialer{}, addr) + if err != nil { + log.Fatalf("Could not establish a splitting socket: %v", err) + } + tlsConn2 := tls.Client(splitConn, tlsConfig) + err = tlsConn2.Handshake() + if err != nil { + log.Printf("Split TLS handshake failed: %v", err) + } else { + log.Printf("Split TLS succeeded") + } +} diff --git a/Android/tun2socks/intra/split/retrier.go b/Android/tun2socks/intra/split/retrier.go new file mode 100644 index 00000000..da24065a --- /dev/null +++ b/Android/tun2socks/intra/split/retrier.go @@ -0,0 +1,358 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package split + +import ( + "errors" + "io" + "math/rand" + "net" + "sync" + "time" + + "github.com/Jigsaw-Code/getsni" +) + +type RetryStats struct { + SNI string // TLS SNI observed, if present. + Bytes int32 // Number of bytes uploaded before the retry. + Chunks int16 // Number of writes before the retry. + Split int16 // Number of bytes in the first retried segment. + Timeout bool // True if the retry was caused by a timeout. +} + +// retrier implements the DuplexConn interface. +type retrier struct { + // mutex is a lock that guards `conn`, `hello`, and `retryCompleteFlag`. + // These fields must not be modified except under this lock. + // After retryCompletedFlag is closed, these values will not be modified + // again so locking is no longer required for reads. + mutex sync.Mutex + dialer *net.Dialer + network string + addr *net.TCPAddr + // conn is the current underlying connection. It is only modified by the reader + // thread, so the reader functions may access it without acquiring a lock. + conn *net.TCPConn + // External read and write deadlines. These need to be stored here so that + // they can be re-applied in the event of a retry. + readDeadline time.Time + writeDeadline time.Time + // Time to wait between the first write and the first read before triggering a + // retry. + timeout time.Duration + // hello is the contents written before the first read. It is initially empty, + // and is cleared when the first byte is received. + hello []byte + // Flag indicating when retry is finished or unnecessary. + retryCompleteFlag chan struct{} + // Flags indicating whether the caller has called CloseRead and CloseWrite. + readCloseFlag chan struct{} + writeCloseFlag chan struct{} + stats *RetryStats +} + +// Helper functions for reading flags. +// In this package, a "flag" is a thread-safe single-use status indicator that +// starts in the "open" state and transitions to "closed" when close() is called. +// It is implemented as a channel over which no data is ever sent. +// Some advantages of this implementation: +// - The language enforces the one-way transition. +// - Nonblocking and blocking access are both straightforward. +// - Checking the status of a closed flag should be extremely fast (although currently +// it's not optimized: https://github.com/golang/go/issues/32529) +func closed(c chan struct{}) bool { + select { + case <-c: + // The channel has been closed. + return true + default: + return false + } +} + +func (r *retrier) readClosed() bool { + return closed(r.readCloseFlag) +} + +func (r *retrier) writeClosed() bool { + return closed(r.writeCloseFlag) +} + +func (r *retrier) retryCompleted() bool { + return closed(r.retryCompleteFlag) +} + +// Given timestamps immediately before and after a successful socket connection +// (i.e. the time the SYN was sent and the time the SYNACK was received), this +// function returns a reasonable timeout for replies to a hello sent on this socket. +func timeout(before, after time.Time) time.Duration { + // These values were chosen to have a <1% false positive rate based on test data. + // False positives trigger an unnecessary retry, which can make connections slower, so they are + // worth avoiding. However, overly long timeouts make retry slower and less useful. + rtt := after.Sub(before) + return 1200*time.Millisecond + 2*rtt +} + +// DefaultTimeout is the value that will cause DialWithSplitRetry to use the system's +// default TCP timeout (typically 2-3 minutes). +const DefaultTimeout time.Duration = 0 + +// DialWithSplitRetry returns a TCP connection that transparently retries by +// splitting the initial upstream segment if the socket closes without receiving a +// reply. Like net.Conn, it is intended for two-threaded use, with one thread calling +// Read and CloseRead, and another calling Write, ReadFrom, and CloseWrite. +// `dialer` will be used to establish the connection. +// `addr` is the destination. +// If `stats` is non-nil, it will be populated with retry-related information. +func DialWithSplitRetry(dialer *net.Dialer, addr *net.TCPAddr, stats *RetryStats) (DuplexConn, error) { + before := time.Now() + conn, err := dialer.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + after := time.Now() + + if stats == nil { + // This is a fake stats object that will be written but never read. Its purpose + // is to avoid the need for nil checks at each point where stats are updated. + stats = &RetryStats{} + } + + r := &retrier{ + dialer: dialer, + addr: addr, + conn: conn.(*net.TCPConn), + timeout: timeout(before, after), + retryCompleteFlag: make(chan struct{}), + readCloseFlag: make(chan struct{}), + writeCloseFlag: make(chan struct{}), + stats: stats, + } + + return r, nil +} + +// Read-related functions. +func (r *retrier) Read(buf []byte) (n int, err error) { + n, err = r.conn.Read(buf) + if n == 0 && err == nil { + // If no data was read, a nil error doesn't rule out the need for a retry. + return + } + if !r.retryCompleted() { + r.mutex.Lock() + if err != nil { + var neterr net.Error + if errors.As(err, &neterr) { + r.stats.Timeout = neterr.Timeout() + } + // Read failed. Retry. + n, err = r.retry(buf) + } + close(r.retryCompleteFlag) + // Unset read deadline. + r.conn.SetReadDeadline(time.Time{}) + r.hello = nil + r.mutex.Unlock() + } + return +} + +func (r *retrier) retry(buf []byte) (n int, err error) { + r.conn.Close() + var newConn net.Conn + if newConn, err = r.dialer.Dial(r.addr.Network(), r.addr.String()); err != nil { + return + } + r.conn = newConn.(*net.TCPConn) + first, second := splitHello(r.hello) + r.stats.Split = int16(len(first)) + if _, err = r.conn.Write(first); err != nil { + return + } + if _, err = r.conn.Write(second); err != nil { + return + } + // While we were creating the new socket, the caller might have called CloseRead + // or CloseWrite on the old socket. Copy that state to the new socket. + // CloseRead and CloseWrite are idempotent, so this is safe even if the user's + // action actually affected the new socket. + if r.readClosed() { + r.conn.CloseRead() + } + if r.writeClosed() { + r.conn.CloseWrite() + } + // The caller might have set read or write deadlines before the retry. + r.conn.SetReadDeadline(r.readDeadline) + r.conn.SetWriteDeadline(r.writeDeadline) + return r.conn.Read(buf) +} + +func (r *retrier) CloseRead() error { + if !r.readClosed() { + close(r.readCloseFlag) + } + r.mutex.Lock() + defer r.mutex.Unlock() + return r.conn.CloseRead() +} + +func splitHello(hello []byte) ([]byte, []byte) { + if len(hello) == 0 { + return hello, hello + } + const ( + MIN_SPLIT int = 32 + MAX_SPLIT int = 64 + ) + + // Random number in the range [MIN_SPLIT, MAX_SPLIT] + s := MIN_SPLIT + rand.Intn(MAX_SPLIT+1-MIN_SPLIT) + limit := len(hello) / 2 + if s > limit { + s = limit + } + return hello[:s], hello[s:] +} + +// Write-related functions +func (r *retrier) Write(b []byte) (int, error) { + // Double-checked locking pattern. This avoids lock acquisition on + // every packet after retry completes, while also ensuring that r.hello is + // empty at steady-state. + if !r.retryCompleted() { + n := 0 + var err error + attempted := false + r.mutex.Lock() + if !r.retryCompleted() { + n, err = r.conn.Write(b) + attempted = true + r.hello = append(r.hello, b[:n]...) + + r.stats.Chunks++ + r.stats.Bytes = int32(len(r.hello)) + if r.stats.SNI == "" { + r.stats.SNI, _ = getsni.GetSNI(r.hello) + } + + // We require a response or another write within the specified timeout. + r.conn.SetReadDeadline(time.Now().Add(r.timeout)) + } + r.mutex.Unlock() + if attempted { + if err == nil { + return n, nil + } + // A write error occurred on the provisional socket. This should be handled + // by the retry procedure. Block until we have a final socket (which will + // already have replayed b[:n]), and retry. + <-r.retryCompleteFlag + r.mutex.Lock() + r.mutex.Unlock() + m, err := r.conn.Write(b[n:]) + return n + m, err + } + } + + // retryCompleted() is true, so r.conn is final and doesn't need locking. + return r.conn.Write(b) +} + +// Copy one buffer from src to dst, using dst.Write. +func copyOnce(dst io.Writer, src io.Reader) (int64, error) { + // This buffer is large enough to hold any ordinary first write + // without introducing extra splitting. + buf := make([]byte, 2048) + n, err := src.Read(buf) + if err != nil { + return 0, err + } + n, err = dst.Write(buf[:n]) + return int64(n), err +} + +func (r *retrier) ReadFrom(reader io.Reader) (bytes int64, err error) { + for !r.retryCompleted() { + if bytes, err = copyOnce(r, reader); err != nil { + return + } + } + + var b int64 + b, err = r.conn.ReadFrom(reader) + bytes += b + return +} + +func (r *retrier) CloseWrite() error { + if !r.writeClosed() { + close(r.writeCloseFlag) + } + r.mutex.Lock() + defer r.mutex.Unlock() + return r.conn.CloseWrite() +} + +func (r *retrier) Close() error { + if err := r.CloseWrite(); err != nil { + return err + } + return r.CloseRead() +} + +// LocalAddr behaves slightly strangely: its value may change as a +// result of a retry. However, LocalAddr is largely useless for +// TCP client sockets anyway, so nothing should be relying on this. +func (r *retrier) LocalAddr() net.Addr { + r.mutex.Lock() + defer r.mutex.Unlock() + return r.conn.LocalAddr() +} + +func (r *retrier) RemoteAddr() net.Addr { + return r.addr +} + +func (r *retrier) SetReadDeadline(t time.Time) error { + r.mutex.Lock() + defer r.mutex.Unlock() + r.readDeadline = t + // Don't enforce read deadlines until after the retry + // is complete. Retry relies on setting its own read + // deadline, and we don't want this to interfere. + if r.retryCompleted() { + return r.conn.SetReadDeadline(t) + } + return nil +} + +func (r *retrier) SetWriteDeadline(t time.Time) error { + r.mutex.Lock() + defer r.mutex.Unlock() + r.writeDeadline = t + return r.conn.SetWriteDeadline(t) +} + +func (r *retrier) SetDeadline(t time.Time) error { + e1 := r.SetReadDeadline(t) + e2 := r.SetWriteDeadline(t) + if e1 != nil { + return e1 + } + return e2 +} diff --git a/Android/tun2socks/intra/split/retrier_test.go b/Android/tun2socks/intra/split/retrier_test.go new file mode 100644 index 00000000..d539c2e9 --- /dev/null +++ b/Android/tun2socks/intra/split/retrier_test.go @@ -0,0 +1,307 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package split + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +type setup struct { + t *testing.T + server *net.TCPListener + clientSide DuplexConn + serverSide *net.TCPConn + serverReceived []byte + stats *RetryStats +} + +func makeSetup(t *testing.T) *setup { + addr, err := net.ResolveTCPAddr("tcp", ":0") + if err != nil { + t.Error(err) + } + server, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Error(err) + } + + serverAddr, ok := server.Addr().(*net.TCPAddr) + if !ok { + t.Error("Server isn't TCP?") + } + var stats RetryStats + clientSide, err := DialWithSplitRetry(&net.Dialer{}, serverAddr, &stats) + if err != nil { + t.Error(err) + } + serverSide, err := server.AcceptTCP() + if err != nil { + t.Error(err) + } + return &setup{t, server, clientSide, serverSide, nil, &stats} +} + +const BUFSIZE = 256 + +func makeBuffer() []byte { + buffer := make([]byte, BUFSIZE) + for i := 0; i < BUFSIZE; i++ { + buffer[i] = byte(i) + } + return buffer +} + +func send(src io.Writer, dest io.Reader, t *testing.T) []byte { + buffer := makeBuffer() + n, err := src.Write(buffer) + if err != nil { + t.Error(err) + } + if n != len(buffer) { + t.Errorf("Failed to write whole buffer: %d", n) + } + + buf := make([]byte, len(buffer)) + n, err = dest.Read(buf) + if err != nil { + t.Error(nil) + } + if n != len(buf) { + t.Errorf("Not enough bytes: %d", n) + } + if !bytes.Equal(buf, buffer) { + t.Errorf("Wrong contents") + } + return buf +} + +func (s *setup) sendUp() { + buf := send(s.clientSide, s.serverSide, s.t) + s.serverReceived = append(s.serverReceived, buf...) +} + +func (s *setup) sendDown() { + send(s.serverSide, s.clientSide, s.t) +} + +func closeRead(closed, blocked DuplexConn, t *testing.T) { + closed.CloseRead() + // TODO: Figure out if this is detectable on the opposite side. +} + +func closeWrite(closed, blocked DuplexConn, t *testing.T) { + closed.CloseWrite() + n, err := blocked.Read(make([]byte, 1)) + if err != io.EOF || n > 0 { + t.Errorf("Read should have failed with EOF") + } +} + +func (s *setup) closeReadUp() { + closeRead(s.clientSide, s.serverSide, s.t) +} + +func (s *setup) closeWriteUp() { + closeWrite(s.clientSide, s.serverSide, s.t) +} + +func (s *setup) closeReadDown() { + closeRead(s.serverSide, s.clientSide, s.t) +} + +func (s *setup) closeWriteDown() { + closeWrite(s.serverSide, s.clientSide, s.t) +} + +func (s *setup) close() { + s.server.Close() +} + +func (s *setup) confirmRetry() { + done := make(chan struct{}) + go func() { + buf := make([]byte, len(s.serverReceived)) + n, err := s.clientSide.Read(buf) + if err != nil { + s.t.Error(err) + } + if n != len(buf) { + s.t.Error("Unexpected echo length") + } + close(done) + }() + + var err error + s.serverSide, err = s.server.AcceptTCP() + if err != nil { + s.t.Errorf("Second socket failed") + } + buf := make([]byte, len(s.serverReceived)) + var n int + for n < len(buf) { + var m int + m, err = s.serverSide.Read(buf[n:]) + n += m + if err != nil { + s.t.Error(err) + } + } + if !bytes.Equal(buf, s.serverReceived) { + s.t.Errorf("Replay was corrupted") + } + + n, err = s.serverSide.Write(buf) + if err != nil { + s.t.Error(err) + } + if n != len(buf) { + s.t.Errorf("Couldn't echo all bytes: %d", n) + } + <-done +} + +func (s *setup) checkNoSplit() { + if s.stats.Split > 0 { + s.t.Error("Retry should not have occurred") + } +} + +func (s *setup) checkStats(bytes int32, chunks int16, timeout bool) { + r := s.stats + if r.Bytes != bytes { + s.t.Errorf("Expected %d bytes, got %d", bytes, r.Bytes) + } + if r.Chunks != chunks { + s.t.Errorf("Expected %d chunks, got %d", chunks, r.Chunks) + } + if r.Timeout != timeout { + s.t.Errorf("Expected timeout to be %t", timeout) + } + if r.Split < 32 || r.Split > 64 { + s.t.Errorf("Unexpected split: %d", r.Split) + } +} + +func TestNormalConnection(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.sendDown() + s.closeReadUp() + s.closeWriteUp() + s.close() + s.checkNoSplit() +} + +func TestFinRetry(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.serverSide.Close() + s.confirmRetry() + s.sendDown() + s.closeReadUp() + s.closeWriteUp() + s.close() + s.checkStats(BUFSIZE, 1, false) +} + +func TestTimeoutRetry(t *testing.T) { + s := makeSetup(t) + s.sendUp() + // Client should time out and retry after about 1.2 seconds + time.Sleep(2 * time.Second) + s.confirmRetry() + s.sendDown() + s.closeReadUp() + s.closeWriteUp() + s.close() + s.checkStats(BUFSIZE, 1, true) +} + +func TestTwoWriteRetry(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.sendUp() + s.serverSide.Close() + s.confirmRetry() + s.sendDown() + s.closeReadUp() + s.closeWriteUp() + s.close() + s.checkStats(2*BUFSIZE, 2, false) +} + +func TestFailedRetry(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.serverSide.Close() + s.confirmRetry() + s.closeReadDown() + s.closeWriteDown() + s.close() + s.checkStats(BUFSIZE, 1, false) +} + +func TestDisappearingServer(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.close() + s.serverSide.Close() + // Try to read 1 byte to trigger the retry. + n, err := s.clientSide.Read(make([]byte, 1)) + if n > 0 || err == nil { + t.Error("Expected read to fail") + } + s.clientSide.CloseRead() + s.clientSide.CloseWrite() + s.checkNoSplit() +} + +func TestSequentialClose(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.closeWriteUp() + s.sendDown() + s.closeWriteDown() + s.close() + s.checkNoSplit() +} + +func TestBackwardsUse(t *testing.T) { + s := makeSetup(t) + s.sendDown() + s.closeWriteDown() + s.sendUp() + s.closeWriteUp() + s.close() + s.checkNoSplit() +} + +// Regression test for an issue in which the initial handshake timeout +// continued to apply after the handshake completed. +func TestIdle(t *testing.T) { + s := makeSetup(t) + s.sendUp() + s.sendDown() + // Wait for longer than the 1.2-second response timeout + time.Sleep(2 * time.Second) + // Try to send down some more data. + s.sendDown() + s.close() + s.checkNoSplit() +} diff --git a/Android/tun2socks/intra/stream_dialer.go b/Android/tun2socks/intra/stream_dialer.go new file mode 100644 index 00000000..aee50b13 --- /dev/null +++ b/Android/tun2socks/intra/stream_dialer.go @@ -0,0 +1,142 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sync/atomic" + "time" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/protect" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/split" + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type intraStreamDialer struct { + fakeDNSAddr netip.AddrPort + dns atomic.Pointer[doh.Transport] + dialer *net.Dialer + alwaysSplitHTTPS atomic.Bool + listener TCPListener + sniReporter *tcpSNIReporter +} + +var _ transport.StreamDialer = (*intraStreamDialer)(nil) + +func newIntraStreamDialer( + fakeDNS netip.AddrPort, + dns doh.Transport, + protector protect.Protector, + listener TCPListener, + sniReporter *tcpSNIReporter, +) (*intraStreamDialer, error) { + if dns == nil { + return nil, errors.New("dns is required") + } + + dohsd := &intraStreamDialer{ + fakeDNSAddr: fakeDNS, + dialer: protect.MakeDialer(protector), + listener: listener, + sniReporter: sniReporter, + } + dohsd.dns.Store(&dns) + return dohsd, nil +} + +// Dial implements StreamDialer.Dial. +func (sd *intraStreamDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { + dest, err := netip.ParseAddrPort(raddr) + if err != nil { + return nil, fmt.Errorf("invalid raddr (%v): %w", raddr, err) + } + + if isEquivalentAddrPort(dest, sd.fakeDNSAddr) { + src, dst := net.Pipe() + go doh.Accept(*sd.dns.Load(), dst) + return newStreamConnFromPipeConns(src, dst) + } + + stats := makeTCPSocketSummary(dest) + beforeConn := time.Now() + conn, err := sd.dial(ctx, dest, stats) + if err != nil { + return nil, fmt.Errorf("failed to dial to target: %w", err) + } + stats.Synack = int32(time.Since(beforeConn).Milliseconds()) + + return makeTCPWrapConn(conn, stats, sd.listener, sd.sniReporter), nil +} + +func (sd *intraStreamDialer) SetDNS(dns doh.Transport) error { + if dns == nil { + return errors.New("dns is required") + } + sd.dns.Store(&dns) + return nil +} + +func (sd *intraStreamDialer) dial(ctx context.Context, dest netip.AddrPort, stats *TCPSocketSummary) (transport.StreamConn, error) { + if dest.Port() == 443 { + if sd.alwaysSplitHTTPS.Load() { + return split.DialWithSplit(sd.dialer, net.TCPAddrFromAddrPort(dest)) + } else { + stats.Retry = &split.RetryStats{} + return split.DialWithSplitRetry(sd.dialer, net.TCPAddrFromAddrPort(dest), stats.Retry) + } + } else { + tcpsd := &transport.TCPStreamDialer{ + Dialer: *sd.dialer, + } + return tcpsd.Dial(ctx, dest.String()) + } +} + +// transport.StreamConn wrapper around net.Pipe call + +type pipeconn struct { + net.Conn + remote net.Conn +} + +var _ transport.StreamConn = (*pipeconn)(nil) + +// newStreamConnFromPipeConns creates a new [transport.StreamConn] that wraps around the local [net.Conn]. +// The remote [net.Conn] will be closed when you call CloseRead() on the returned [transport.StreamConn] +func newStreamConnFromPipeConns(local, remote net.Conn) (transport.StreamConn, error) { + if local == nil || remote == nil { + return nil, errors.New("local conn and remote conn are required") + } + return &pipeconn{local, remote}, nil +} + +func (c *pipeconn) Close() error { + return errors.Join(c.CloseRead(), c.CloseWrite()) +} + +// CloseRead makes sure all read on the local conn returns io.EOF, and write on the remote conn returns ErrClosedPipe. +func (c *pipeconn) CloseRead() error { + return c.remote.Close() +} + +// CloseWrite makes sure all read on the remote conn returns io.EOF, and write on the local conn returns ErrClosedPipe. +func (c *pipeconn) CloseWrite() error { + return c.Conn.Close() +} diff --git a/Android/tun2socks/intra/tcp.go b/Android/tun2socks/intra/tcp.go new file mode 100644 index 00000000..7b252364 --- /dev/null +++ b/Android/tun2socks/intra/tcp.go @@ -0,0 +1,144 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Derived from go-tun2socks's "direct" handler under the Apache 2.0 license. + +package intra + +import ( + "io" + "net/netip" + "sync" + "sync/atomic" + "time" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/split" + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// TCPSocketSummary provides information about each TCP socket, reported when it is closed. +type TCPSocketSummary struct { + DownloadBytes int64 // Total bytes downloaded. + UploadBytes int64 // Total bytes uploaded. + Duration int32 // Duration in seconds. + ServerPort int16 // The server port. All values except 80, 443, and 0 are set to -1. + Synack int32 // TCP handshake latency (ms) + // Retry is non-nil if retry was possible. Retry.Split is non-zero if a retry occurred. + Retry *split.RetryStats +} + +func makeTCPSocketSummary(dest netip.AddrPort) *TCPSocketSummary { + stats := &TCPSocketSummary{ + ServerPort: int16(dest.Port()), + } + if stats.ServerPort != 0 && stats.ServerPort != 80 && stats.ServerPort != 443 { + stats.ServerPort = -1 + } + return stats +} + +// TCPListener is notified when a socket closes. +type TCPListener interface { + OnTCPSocketClosed(*TCPSocketSummary) +} + +type tcpWrapConn struct { + transport.StreamConn + + wg *sync.WaitGroup + rDone, wDone atomic.Bool + + beginTime time.Time + stats *TCPSocketSummary + + listener TCPListener + sniReporter *tcpSNIReporter +} + +func makeTCPWrapConn(c transport.StreamConn, stats *TCPSocketSummary, listener TCPListener, sniReporter *tcpSNIReporter) (conn *tcpWrapConn) { + conn = &tcpWrapConn{ + StreamConn: c, + wg: &sync.WaitGroup{}, + beginTime: time.Now(), + stats: stats, + listener: listener, + sniReporter: sniReporter, + } + + // Wait until both read and write are done + conn.wg.Add(2) + go func() { + conn.wg.Wait() + conn.stats.Duration = int32(time.Since(conn.beginTime)) + if conn.listener != nil { + conn.listener.OnTCPSocketClosed(conn.stats) + } + if conn.stats.Retry != nil && conn.sniReporter != nil { + conn.sniReporter.Report(*conn.stats) + } + }() + + return +} + +func (conn *tcpWrapConn) Close() error { + defer conn.close(&conn.wDone) + defer conn.close(&conn.rDone) + return conn.StreamConn.Close() +} + +func (conn *tcpWrapConn) CloseRead() error { + defer conn.close(&conn.rDone) + return conn.StreamConn.CloseRead() +} + +func (conn *tcpWrapConn) CloseWrite() error { + defer conn.close(&conn.wDone) + return conn.StreamConn.CloseWrite() +} + +func (conn *tcpWrapConn) Read(b []byte) (n int, err error) { + defer func() { + conn.stats.DownloadBytes += int64(n) + }() + return conn.StreamConn.Read(b) +} + +func (conn *tcpWrapConn) WriteTo(w io.Writer) (n int64, err error) { + defer func() { + conn.stats.DownloadBytes += n + }() + return io.Copy(w, conn.StreamConn) +} + +func (conn *tcpWrapConn) Write(b []byte) (n int, err error) { + defer func() { + conn.stats.UploadBytes += int64(n) + }() + return conn.StreamConn.Write(b) +} + +func (conn *tcpWrapConn) ReadFrom(r io.Reader) (n int64, err error) { + defer func() { + conn.stats.UploadBytes += n + }() + return io.Copy(conn.StreamConn, r) +} + +func (conn *tcpWrapConn) close(done *atomic.Bool) { + // make sure conn.wg is being called at most once for a specific `done` flag + if done.CompareAndSwap(false, true) { + conn.wg.Done() + } +} diff --git a/Android/tun2socks/intra/tunnel.go b/Android/tun2socks/intra/tunnel.go new file mode 100644 index 00000000..ae7abcc6 --- /dev/null +++ b/Android/tun2socks/intra/tunnel.go @@ -0,0 +1,123 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package intra + +import ( + "errors" + "fmt" + "io" + "net" + "os" + "strings" + + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/doh" + "github.com/Jigsaw-Code/Intra/Android/tun2socks/intra/protect" + "github.com/Jigsaw-Code/outline-sdk/network" + "github.com/Jigsaw-Code/outline-sdk/network/lwip2transport" +) + +// Listener receives usage statistics when a UDP or TCP socket is closed, +// or a DNS query is completed. +type Listener interface { + UDPListener + TCPListener + doh.Listener +} + +// Tunnel represents an Intra session. +type Tunnel struct { + network.IPDevice + + sd *intraStreamDialer + pp *intraPacketProxy + sni *tcpSNIReporter + tun io.Closer +} + +// NewTunnel creates a connected Intra session. +// +// `fakedns` is the DNS server (IP and port) that will be used by apps on the TUN device. +// +// This will normally be a reserved or remote IP address, port 53. +// +// `udpdns` and `tcpdns` are the actual location of the DNS server in use. +// +// These will normally be localhost with a high-numbered port. +// +// `dohdns` is the initial DOH transport. +// `eventListener` will be notified at the completion of every tunneled socket. +func NewTunnel( + fakedns string, dohdns doh.Transport, tun io.Closer, protector protect.Protector, eventListener Listener, +) (t *Tunnel, err error) { + if eventListener == nil { + return nil, errors.New("eventListener is required") + } + + fakeDNSAddr, err := net.ResolveUDPAddr("udp", fakedns) + if err != nil { + return nil, fmt.Errorf("failed to resolve fakedns: %w", err) + } + + t = &Tunnel{ + sni: &tcpSNIReporter{ + dns: dohdns, + }, + tun: tun, + } + + t.sd, err = newIntraStreamDialer(fakeDNSAddr.AddrPort(), dohdns, protector, eventListener, t.sni) + if err != nil { + return nil, fmt.Errorf("failed to create stream dialer: %w", err) + } + + t.pp, err = newIntraPacketProxy(fakeDNSAddr.AddrPort(), dohdns, protector, eventListener) + if err != nil { + return nil, fmt.Errorf("failed to create packet proxy: %w", err) + } + + if t.IPDevice, err = lwip2transport.ConfigureDevice(t.sd, t.pp); err != nil { + return nil, fmt.Errorf("failed to configure lwIP stack: %w", err) + } + + t.SetDNS(dohdns) + return +} + +// Set the DNSTransport. This method must be called before connecting the transport +// to the TUN device. The transport can be changed at any time during operation, but +// must not be nil. +func (t *Tunnel) SetDNS(dns doh.Transport) { + t.sd.SetDNS(dns) + t.pp.SetDNS(dns) + t.sni.SetDNS(dns) +} + +// Enable reporting of SNIs that resulted in connection failures, using the +// Choir library for privacy-preserving error reports. `file` is the path +// that Choir should use to store its persistent state, `suffix` is the +// authoritative domain to which reports will be sent, and `country` is a +// two-letter ISO country code for the user's current location. +func (t *Tunnel) EnableSNIReporter(filename, suffix, country string) error { + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return err + } + return t.sni.Configure(f, suffix, strings.ToLower(country)) +} + +func (t *Tunnel) Disconnect() { + t.Close() + t.tun.Close() +} diff --git a/Android/tun2socks/intra/udp.go b/Android/tun2socks/intra/udp.go new file mode 100644 index 00000000..f3542fca --- /dev/null +++ b/Android/tun2socks/intra/udp.go @@ -0,0 +1,46 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Derived from go-tun2socks's "direct" handler under the Apache 2.0 license. + +package intra + +import ( + "sync/atomic" + "time" +) + +// UDPSocketSummary describes a non-DNS UDP association, reported when it is discarded. +type UDPSocketSummary struct { + UploadBytes int64 // Amount uploaded (bytes) + DownloadBytes int64 // Amount downloaded (bytes) + Duration int32 // How long the socket was open (seconds) +} + +// UDPListener is notified when a non-DNS UDP association is discarded. +type UDPListener interface { + OnUDPSocketClosed(*UDPSocketSummary) +} + +type tracker struct { + start time.Time + upload atomic.Int64 // Non-DNS upload bytes + download atomic.Int64 // Non-DNS download bytes +} + +func makeTracker() *tracker { + return &tracker{ + start: time.Now(), + } +} diff --git a/Android/tun2socks/src/main/AndroidManifest.xml b/Android/tun2socks/src/main/AndroidManifest.xml new file mode 100644 index 00000000..788b4795 --- /dev/null +++ b/Android/tun2socks/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + + diff --git a/Android/tun2socks/tun2socks.aar b/Android/tun2socks/tun2socks.aar deleted file mode 100644 index c5e585b5..00000000 Binary files a/Android/tun2socks/tun2socks.aar and /dev/null differ diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..d9b8bd37 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module github.com/Jigsaw-Code/Intra + +go 1.21.1 + +require ( + github.com/Jigsaw-Code/choir v1.0.1 + github.com/Jigsaw-Code/getsni v1.0.0 + github.com/Jigsaw-Code/outline-sdk v0.0.7 + github.com/eycorsican/go-tun2socks v1.16.11 + golang.org/x/mobile v0.0.0-20231006135142-2b44d11868fe + golang.org/x/net v0.16.0 + golang.org/x/sys v0.13.0 +) + +require ( + golang.org/x/crypto v0.14.0 // indirect + golang.org/x/mod v0.13.0 // indirect + golang.org/x/sync v0.4.0 // indirect + golang.org/x/tools v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..166b2e17 --- /dev/null +++ b/go.sum @@ -0,0 +1,39 @@ +github.com/Jigsaw-Code/choir v1.0.1 h1:WeRt6aTn5L+MtRNqRJ+J1RKgoO8CyXXt1dtZghy2KjE= +github.com/Jigsaw-Code/choir v1.0.1/go.mod h1:c4Wd1y1PeCajZbKZV+ZmcFGMDoduyqMCEMHW5iqzWXI= +github.com/Jigsaw-Code/getsni v1.0.0 h1:OUTIu7wTBi/7DMX+RkZrN7XhU3UDevTEsAWK4gsqSwE= +github.com/Jigsaw-Code/getsni v1.0.0/go.mod h1:Ps0Ec3fVMKLyAItVbMKoQFq1lDjtFQXZ+G5nRNNh/QE= +github.com/Jigsaw-Code/outline-sdk v0.0.7 h1:WlFaV1tFpIQ/pflrKwrQuNIP3kJpgh7yJuqiTb54sGA= +github.com/Jigsaw-Code/outline-sdk v0.0.7/go.mod h1:hhlKz0+r9wSDFT8usvN8Zv/BFToCIFAUn1P2Qk8G2CM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/eycorsican/go-tun2socks v1.16.11 h1:+hJDNgisrYaGEqoSxhdikMgMJ4Ilfwm/IZDrWRrbaH8= +github.com/eycorsican/go-tun2socks v1.16.11/go.mod h1:wgB2BFT8ZaPKyKOQ/5dljMG/YIow+AIXyq4KBwJ5sGQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/mobile v0.0.0-20231006135142-2b44d11868fe h1:lrXv4yHeD9FA8PSJATWowP1QvexpyAPWmPia+Kbzql8= +golang.org/x/mobile v0.0.0-20231006135142-2b44d11868fe/go.mod h1:BrnXpEObnFxpaT75Jo9hsCazwOWcp7nVIa8NNuH5cuA= +golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= +golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= +golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= +golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tools.go b/tools.go new file mode 100644 index 00000000..6b45bed1 --- /dev/null +++ b/tools.go @@ -0,0 +1,25 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build tools +// +build tools + +// See https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module + +package tools + +import ( + _ "golang.org/x/mobile/cmd/gobind" + _ "golang.org/x/mobile/cmd/gomobile" +)