diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift index 7b5b77d03..a337968f7 100644 --- a/Sources/NetworkProtection/PacketTunnelProvider.swift +++ b/Sources/NetworkProtection/PacketTunnelProvider.swift @@ -34,7 +34,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { case userBecameActive case reportConnectionAttempt(attempt: ConnectionAttempt) case tunnelStartAttempt(_ step: TunnelStartAttemptStep) + case tunnelStopAttempt(_ step: TunnelStopAttemptStep) case tunnelUpdateAttempt(_ step: TunnelUpdateAttemptStep) + case tunnelWakeAttempt(_ step: TunnelWakeAttemptStep) case reportTunnelFailure(result: NetworkProtectionTunnelFailureMonitor.Result) case reportLatency(result: NetworkProtectionLatencyMonitor.Result) case rekeyAttempt(_ step: RekeyAttemptStep) @@ -47,7 +49,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } public typealias TunnelStartAttemptStep = AttemptStep + public typealias TunnelStopAttemptStep = AttemptStep public typealias TunnelUpdateAttemptStep = AttemptStep + public typealias TunnelWakeAttemptStep = AttemptStep public typealias RekeyAttemptStep = AttemptStep public enum ConnectionAttempt { @@ -209,10 +213,10 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { return } - await rekey() + try? await rekey() } - private func rekey() async { + private func rekey() async throws { providerEvents.fire(.userBecameActive) // Experimental option to disable rekeying. @@ -230,6 +234,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } catch { os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error)) providerEvents.fire(.rekeyAttempt(.failure(error))) + throw error } } @@ -664,54 +669,70 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: - Tunnel Stop - open override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { + @MainActor + open override func stopTunnel(with reason: NEProviderStopReason) async { + providerEvents.fire(.tunnelStopAttempt(.begin)) - Task { @MainActor in - await stopMonitors() + os_log("Stopping tunnel with reason %{public}@", log: .networkProtection, type: .info, String(describing: reason)) - connectionStatus = .disconnecting - os_log("Stopping tunnel with reason %{public}@", log: .networkProtection, type: .info, String(describing: reason)) + do { + try await stopTunnel() + providerEvents.fire(.tunnelStopAttempt(.success)) + } catch { + providerEvents.fire(.tunnelStopAttempt(.failure(error))) + } - adapter.stop { [weak self] error in - if let error { - os_log("🔵 Failed to stop WireGuard adapter: %{public}@", log: .networkProtection, type: .info, error.localizedDescription) - self?.debugEvents?.fire(error.networkProtectionError) - } + if case .superceded = reason { + self.notificationsPresenter.showSupersededNotification() + } + } - Task { [weak self] in - if let self { - self.handleAdapterStopped() + /// Do not cancel, directly... call this method so that the adapter and tester are stopped too. + @MainActor + private func cancelTunnel(with stopError: Error) async { + providerEvents.fire(.tunnelStopAttempt(.begin)) - if case .superceded = reason { - self.notificationsPresenter.showSupersededNotification() - } - } + os_log("Stopping tunnel with error %{public}@", log: .networkProtection, type: .error, stopError.localizedDescription) - completionHandler() - } - } + do { + try await stopTunnel() + providerEvents.fire(.tunnelStopAttempt(.success)) + } catch { + providerEvents.fire(.tunnelStopAttempt(.failure(error))) } + + cancelTunnelWithError(stopError) } - /// Do not cancel, directly... call this method so that the adapter and tester are stopped too. + // MARK: - Tunnel Stop: Support Methods + + /// Do not call this directly. Call `stopTunnel(with:)` or `cancelTunnel(with:)` instead. + /// @MainActor - private func cancelTunnel(with stopError: Error) async { + private func stopTunnel() async throws { + connectionStatus = .disconnecting await stopMonitors() + try await stopAdapter() + } - connectionStatus = .disconnecting + @MainActor + private func stopAdapter() async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + adapter.stop { [weak self] error in + if let self { + self.handleAdapterStopped() + } - os_log("Stopping tunnel with error %{public}@", log: .networkProtection, type: .info, stopError.localizedDescription) + if let error { + os_log("🔵 Error while stopping adapter: %{public}@", log: .networkProtection, type: .error, error.localizedDescription) + self?.debugEvents?.fire(error.networkProtectionError) - self.adapter.stop { [weak self] error in - guard let self else { return } + continuation.resume(throwing: error) + return + } - if let error = error { - os_log("Error while stopping adapter: %{public}@", log: .networkProtection, type: .info, error.localizedDescription) - debugEvents?.fire(error.networkProtectionError) + continuation.resume() } - - cancelTunnelWithError(stopError) - self.handleAdapterStopped() } } @@ -989,7 +1010,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { private func handleExpireRegistrationKey(completionHandler: ((Data?) -> Void)? = nil) { Task { - await rekey() + try? await rekey() completionHandler?(nil) } } @@ -1181,7 +1202,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { if !settings.disableRekeying { guard !isKeyExpired else { - await rekey() + try await rekey() return } } @@ -1281,7 +1302,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { if #available(iOS 17, *) { handleShutDown() } else { - await rekey() + try? await rekey() } completion?() } @@ -1361,7 +1382,14 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { os_log("Wake up", log: .networkProtectionSleepLog, type: .info) Task { - try? await handleAdapterStarted(startReason: .wake) + providerEvents.fire(.tunnelWakeAttempt(.begin)) + + do { + try await handleAdapterStarted(startReason: .wake) + providerEvents.fire(.tunnelWakeAttempt(.success)) + } catch { + providerEvents.fire(.tunnelWakeAttempt(.failure(error))) + } } } }