-
Notifications
You must be signed in to change notification settings - Fork 0
/
DnsTlsSocket.h
208 lines (172 loc) · 8.79 KB
/
DnsTlsSocket.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef _DNS_DNSTLSSOCKET_H
#define _DNS_DNSTLSSOCKET_H
#include <openssl/ssl.h>

#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
namespace android {
namespace net {
class IDnsTlsSocketObserver;
class DnsTlsSessionCache;
// A class for managing a TLS socket that sends and receives messages in
// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
// This class is not aware of query-response pairing or anything else about DNS.
// For the observer:
// This class is not re-entrant: from within a callback, the observer must not wait for a call
// to query() or for the destructor to complete. Doing so will deadlock.
// This class may call the observer at any time after initialize() until the destructor
// returns, but never after it has returned.
//
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
//
// UNINITIALIZED
// |
// v
// INITIALIZED
// |
// v
// +----CONNECTING------+
// Handshake fails | | Handshake succeeds
// (onClose() when | |
// mAsyncHandshake is set) | v
// | +---> CONNECTED --+
// | | | |
// | +-----------+ | Idle timeout
// | Send/Recv queries | onClose()
// | onResponse() |
// | |
// | |
// +--> WAIT_FOR_DELETE <-----+
//
//
// TODO: Add onHandshakeFinished() for handshake results.
class DnsTlsSocket : public IDnsTlsSocket {
public:
enum class State {
UNINITIALIZED,
INITIALIZED,
CONNECTING,
CONNECTED,
WAIT_FOR_DELETE,
};
DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
: mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
~DnsTlsSocket();
// Creates the SSL context for this session. Returns false on failure.
// This method should be called after construction and before use of a DnsTlsSocket.
// Only call this method once per DnsTlsSocket.
bool initialize() EXCLUDES(mLock);
// If async handshake is enabled, this function simply signals a handshake request, and the
// handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
// this function performs the handshake and returns after the handshake finishes.
bool startHandshake() EXCLUDES(mLock);
// Send a query on the provided SSL socket. |query| contains
// the body of a query, not including the ID header. This function will typically return before
// the query is actually sent. If this function fails, DnsTlsSocketObserver will be
// notified that the socket is closed.
// Note that success here indicates successful sending, not receipt of a response.
// Thread-safe.
bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
private:
// Lock to be held by the SSL event loop thread. This is not normally in contention.
std::mutex mLock;
// Forwards queries and receives responses. Blocks until the idle timeout.
void loop() EXCLUDES(mLock);
std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
// On success, sets mSslFd to a socket connected to mAddr (the
// connection will likely be in progress if mProtocol is IPPROTO_TCP).
// On error, returns the errno.
netdutils::Status tcpConnect() REQUIRES(mLock);
bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);
// Connect an SSL session on the provided socket. If connection fails, closing the
// socket remains the caller's responsibility.
bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
// Connect an SSL session on the provided socket. This is an interruptible version
// which allows to terminate connection handshake any time.
bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);
// Disconnect the SSL session and close the socket.
void sslDisconnect() REQUIRES(mLock);
// Writes a buffer to the socket.
bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
// Reads exactly the specified number of bytes from the socket, or fails.
// Returns SSL_ERROR_NONE on success.
// If |wait| is true, then this function always blocks. Otherwise, it
// will return SSL_ERROR_WANT_READ if there is no data from the server to read.
int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
// Read one DNS response. It can potentially block until reading the exact bytes of
// the response.
bool readResponse() REQUIRES(mLock);
// It is only used for DNS-OVER-TLS internal test.
bool setTestCaCertificate() REQUIRES(mLock);
// Similar to query(), this function uses incrementEventFd to send a message to the
// loop thread. However, instead of incrementing the counter by one (indicating a
// new query), it wraps the counter to negative, which we use to indicate a shutdown
// request.
void requestLoopShutdown() EXCLUDES(mLock);
// This function sends a message to the loop thread by incrementing mEventFd.
bool incrementEventFd(int64_t count) EXCLUDES(mLock);
// Transition the state from expected state |from| to new state |to|.
void transitionState(State from, State to) REQUIRES(mLock);
// Queue of pending queries. query() pushes items onto the queue and notifies
// the loop thread by incrementing mEventFd. loop() reads items off the queue.
LockedQueue<std::vector<uint8_t>> mQueue;
// eventfd socket used for notifying the SSL thread when queries are ready to send.
// This socket acts similarly to an atomic counter, incremented by query() and cleared
// by loop(). We have to use a socket because the SSL thread needs to wait in poll()
// for input from either a remote server or a query thread. Since eventfd does not have
// EOF, we indicate a close request by setting the counter to a negative number.
// This file descriptor is opened by initialize(), and closed implicitly after
// destruction.
// Note that: data starts being read from the eventfd when the state is CONNECTED.
base::unique_fd mEventFd;
// An eventfd used to listen to shutdown requests when the state is CONNECTING.
// TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
// handle shutdown requests.
base::unique_fd mShutdownEvent;
// SSL Socket fields.
bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
base::unique_fd mSslFd GUARDED_BY(mLock);
bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
const unsigned mMark; // Socket mark
const DnsTlsServer mServer;
IDnsTlsSocketObserver* _Nonnull const mObserver;
DnsTlsSessionCache* _Nonnull const mCache;
State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
// If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
// thread (the call to startHandshake()).
bool mAsyncHandshake GUARDED_BY(mLock) = false;
// The time to wait for the attempt on connecting to the server.
// Set the default value 127 seconds to be consistent with TCP connect timeout.
// (presume net.ipv4.tcp_syn_retries = 6)
static constexpr int kDotConnectTimeoutMs = 127 * 1000;
int mConnectTimeoutMs;
// For testing.
friend class DnsTlsSocketTest;
};
} // end of namespace net
} // end of namespace android
#endif // _DNS_DNSTLSSOCKET_H