Skip to content

Commit

Permalink
Merge branch 'main' into sam/add-netp-subscription-auth-support
Browse files Browse the repository at this point in the history
* main:
  Update latency & tunnel failure monitors implementation (#613)
  Prevents VPNSettings from reporting fake changes (#614)
  Update Link Tracking Protection to preserve headers (#600)
  • Loading branch information
samsymons committed Dec 24, 2023
2 parents a2e41db + 18043cb commit 8dd2ee0
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 386 deletions.
37 changes: 27 additions & 10 deletions Sources/BrowserServicesKit/LinkProtection/LinkProtection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,26 @@ public struct LinkProtection {
errorReporting: errorReporting)
}

private func makeNewRequest(changingUrl url: URL, inRequest request: URLRequest) -> URLRequest {
var newRequest = request
newRequest.url = url
return newRequest
}

public mutating func setMainFrameUrl(_ url: URL?) {
mainFrameUrl = url
}

public func getCleanURLRequest(from urlRequest: URLRequest,
onStartExtracting: () -> Void,
onFinishExtracting: @escaping () -> Void,
completion: @escaping (URLRequest) -> Void) {
getCleanURL(from: urlRequest.url!, onStartExtracting: onStartExtracting, onFinishExtracting: onFinishExtracting) { newUrl in
let newRequest = makeNewRequest(changingUrl: newUrl, inRequest: urlRequest)
completion(newRequest)
}
}

public func getCleanURL(from url: URL,
onStartExtracting: () -> Void,
onFinishExtracting: @escaping () -> Void,
Expand Down Expand Up @@ -77,11 +93,12 @@ public struct LinkProtection {

// swiftlint:disable function_parameter_count
public func requestTrackingLinkRewrite(initiatingURL: URL?,
destinationURL: URL,
destinationRequest: URLRequest,
onStartExtracting: () -> Void,
onFinishExtracting: @escaping () -> Void,
onLinkRewrite: @escaping (URL) -> Void,
onLinkRewrite: @escaping (URLRequest) -> Void,
policyDecisionHandler: @escaping (Bool) -> Void) -> Bool {
let destinationURL = destinationRequest.url
if let mainFrameUrl = mainFrameUrl, destinationURL != mainFrameUrl {
// If mainFrameUrl is set and is different from destinationURL we will assume this is a redirect
// We do not rewrite redirects due to breakage concerns
Expand All @@ -91,7 +108,7 @@ public struct LinkProtection {
var didRewriteLink = false
if let newURL = linkCleaner.extractCanonicalFromAMPLink(initiator: initiatingURL, destination: destinationURL) {
policyDecisionHandler(false)
onLinkRewrite(newURL)
onLinkRewrite(makeNewRequest(changingUrl: newURL, inRequest: destinationRequest))
didRewriteLink = true
} else if ampExtractor.urlContainsAMPKeyword(destinationURL) {
onStartExtracting()
Expand All @@ -103,13 +120,13 @@ public struct LinkProtection {
}

policyDecisionHandler(false)
onLinkRewrite(canonical)
onLinkRewrite(makeNewRequest(changingUrl: canonical, inRequest: destinationRequest))
}
didRewriteLink = true
} else if let newURL = linkCleaner.cleanTrackingParameters(initiator: initiatingURL, url: destinationURL) {
if newURL != destinationURL {
policyDecisionHandler(false)
onLinkRewrite(newURL)
onLinkRewrite(makeNewRequest(changingUrl: newURL, inRequest: destinationRequest))
didRewriteLink = true
}
}
Expand All @@ -121,10 +138,10 @@ public struct LinkProtection {
navigationAction: WKNavigationAction,
onStartExtracting: () -> Void,
onFinishExtracting: @escaping () -> Void,
onLinkRewrite: @escaping (URL, WKNavigationAction) -> Void,
onLinkRewrite: @escaping (URLRequest, WKNavigationAction) -> Void,
policyDecisionHandler: @escaping (WKNavigationActionPolicy) -> Void) -> Bool {
requestTrackingLinkRewrite(initiatingURL: initiatingURL,
destinationURL: navigationAction.request.url!,
destinationRequest: navigationAction.request,
onStartExtracting: onStartExtracting,
onFinishExtracting: onFinishExtracting,
onLinkRewrite: { onLinkRewrite($0, navigationAction) },
Expand All @@ -134,13 +151,13 @@ public struct LinkProtection {

@MainActor
public func requestTrackingLinkRewrite(initiatingURL: URL?,
destinationURL: URL,
destinationRequest: URLRequest,
onStartExtracting: () -> Void,
onFinishExtracting: @escaping () -> Void,
onLinkRewrite: @escaping (URL) -> Void) async -> Bool? {
onLinkRewrite: @escaping (URLRequest) -> Void) async -> Bool? {
await withCheckedContinuation { continuation in
let didRewriteLink = requestTrackingLinkRewrite(initiatingURL: initiatingURL,
destinationURL: destinationURL,
destinationRequest: destinationRequest,
onStartExtracting: onStartExtracting,
onFinishExtracting: onFinishExtracting,
onLinkRewrite: onLinkRewrite) { navigationActionPolicy in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Network
import Common
import Combine

final public class NetworkProtectionLatencyMonitor {
public actor NetworkProtectionLatencyMonitor {
public enum ConnectionQuality: String {
case terrible
case poor
Expand Down Expand Up @@ -55,196 +55,103 @@ final public class NetworkProtectionLatencyMonitor {

private static let reportThreshold: TimeInterval = .minutes(10)
private static let measurementInterval: TimeInterval = .seconds(5)
private static let pingTimeout: TimeInterval = 0.3
private static let pingTimeout: TimeInterval = .seconds(1)

private static let unknownLatency: TimeInterval = -1

public var publisher: AnyPublisher<Result, Never> {
subject.eraseToAnyPublisher()
}
private let subject = PassthroughSubject<Result, Never>()

private let latencySubject = PassthroughSubject<TimeInterval, Never>()
private var latencyCancellable: AnyCancellable?

private actor TimerRunCoordinator {
private(set) var isRunning = false

func start() {
isRunning = true
}
private var latencyCancellable: AnyCancellable?

func stop() {
isRunning = false
private var task: Task<Never, Error>? {
willSet {
task?.cancel()
}
}

private var timer: DispatchSourceTimer?
private let timerRunCoordinator = TimerRunCoordinator()
private let timerQueue: DispatchQueue

private let lock = NSLock()

private var _lastLatencyReported: Date = .distantPast
private(set) var lastLatencyReported: Date {
get {
lock.lock(); defer { lock.unlock() }
return _lastLatencyReported
}
set {
lock.lock()
self._lastLatencyReported = newValue
lock.unlock()
}
var isStarted: Bool {
task?.isCancelled == false
}

private let serverIP: () -> IPv4Address?

private let log: OSLog

private var _ignoreThreshold = false
private(set) var ignoreThreshold: Bool {
get {
lock.lock(); defer { lock.unlock() }
return _ignoreThreshold
}
set {
lock.lock()
self._ignoreThreshold = newValue
lock.unlock()
}
}
private var lastLatencyReported: Date = .distantPast

// MARK: - Init & deinit

init(serverIP: @escaping () -> IPv4Address?, timerQueue: DispatchQueue, log: OSLog) {
self.serverIP = serverIP
self.timerQueue = timerQueue
self.log = log

init() {
os_log("[+] %{public}@", log: .networkProtectionMemoryLog, type: .debug, String(describing: self))
}

deinit {
os_log("[-] %{public}@", log: .networkProtectionMemoryLog, type: .debug, String(describing: self))
task?.cancel()

cancelTimerImmediately()
os_log("[-] %{public}@", log: .networkProtectionMemoryLog, type: .debug, String(describing: self))
}

// MARK: - Start/Stop monitoring

public func start() async throws {
guard await !timerRunCoordinator.isRunning else {
os_log("Will not start the latency monitor as it's already running", log: log)
return
}

os_log("⚫️ Starting latency monitor", log: log)
public func start(serverIP: IPv4Address, callback: @escaping (Result) -> Void) {
os_log("⚫️ Starting latency monitor", log: .networkProtectionLatencyMonitorLog)

latencyCancellable = latencySubject.eraseToAnyPublisher()
.scan(ExponentialGeometricAverage()) { [weak self] measurements, latency in
.receive(on: DispatchQueue.main)
.scan(ExponentialGeometricAverage()) { measurements, latency in
if latency >= 0 {
measurements.addMeasurement(latency)
os_log("⚫️ Latency: %{public}f milliseconds", log: .networkProtectionPixel, type: .debug, latency)
os_log("⚫️ Latency: %{public}f milliseconds", log: .networkProtectionLatencyMonitorLog, type: .debug, latency)
} else {
self?.subject.send(.error)
callback(.error)
}

os_log("⚫️ Average: %{public}f milliseconds", log: .networkProtectionPixel, type: .debug, measurements.average)
os_log("⚫️ Average: %{public}f milliseconds", log: .networkProtectionLatencyMonitorLog, type: .debug, measurements.average)

return measurements
}
.map { ConnectionQuality(average: $0.average) }
.sink { [weak self] quality in
let now = Date()
if let self,
(now.timeIntervalSince1970 - self.lastLatencyReported.timeIntervalSince1970 >= Self.reportThreshold) || ignoreThreshold {
self.subject.send(.quality(quality))
self.lastLatencyReported = now
.sink { quality in
Task { [weak self] in
let now = Date()
if let self,
await now.timeIntervalSince1970 - self.lastLatencyReported.timeIntervalSince1970 >= Self.reportThreshold {
callback(.quality(quality))
await self.updateLastLatencyReported(date: now)
}
}
}

do {
try await scheduleTimer()
} catch {
os_log("⚫️ Stopping latency monitor prematurely", log: log)
throw error
task = Task.periodic(interval: Self.measurementInterval) { [weak self] in
await self?.measureLatency(to: serverIP)
}
}

public func stop() async {
os_log("⚫️ Stopping latency monitor", log: log)
await stopScheduledTimer()
}

// MARK: - Timer scheduling

private func scheduleTimer() async throws {
await stopScheduledTimer()

await timerRunCoordinator.start()

let timer = DispatchSource.makeTimerSource(queue: timerQueue)
self.timer = timer

timer.schedule(deadline: .now() + Self.measurementInterval, repeating: Self.measurementInterval)
timer.setEventHandler { [weak self] in
guard let self else { return }

Task {
await self.measureLatency()
}
}

timer.setCancelHandler { [weak self] in
self?.timer = nil
}
public func stop() {
os_log("⚫️ Stopping latency monitor", log: .networkProtectionLatencyMonitorLog)

timer.resume()
latencyCancellable = nil
task = nil
}

private func stopScheduledTimer() async {
await timerRunCoordinator.stop()

cancelTimerImmediately()
}

private func cancelTimerImmediately() {
guard let timer else { return }

if !timer.isCancelled {
timer.cancel()
}

self.timer = nil
private func updateLastLatencyReported(date: Date) {
lastLatencyReported = date
}

// MARK: - Latency monitor

@MainActor
public func measureLatency() async {
guard let serverIP = serverIP() else {
latencySubject.send(Self.unknownLatency)
return
}

os_log("⚫️ Pinging %{public}s", log: .networkProtectionPixel, type: .debug, serverIP.debugDescription)
private func measureLatency(to ip: IPv4Address) async {
os_log("⚫️ Pinging %{public}s", log: .networkProtectionLatencyMonitorLog, type: .debug, ip.debugDescription)

let result = await Pinger(ip: serverIP, timeout: Self.pingTimeout, log: .networkProtectionPixel).ping()
let result = await Pinger(ip: ip, timeout: Self.pingTimeout, log: .networkProtectionLatencyMonitorLog).ping()

switch result {
case .success(let pingResult):
latencySubject.send(pingResult.time * 1000)
case .failure(let error):
os_log("⚫️ Ping error: %{public}s", log: .networkProtectionPixel, type: .debug, error.localizedDescription)
os_log("⚫️ Ping error: %{public}s", log: .networkProtectionLatencyMonitorLog, type: .debug, error.localizedDescription)
latencySubject.send(Self.unknownLatency)
}
}

public func simulateLatency(_ timeInterval: TimeInterval) {
ignoreThreshold = true
latencySubject.send(timeInterval)
ignoreThreshold = false
}
}

Expand Down
Loading

0 comments on commit 8dd2ee0

Please sign in to comment.