diff --git a/Sources/NetworkProtection/Models/VPNServerSelectionResolver.swift b/Sources/NetworkProtection/Models/VPNServerSelectionResolver.swift new file mode 100644 index 000000000..c4b3795a0 --- /dev/null +++ b/Sources/NetworkProtection/Models/VPNServerSelectionResolver.swift @@ -0,0 +1,106 @@ +// +// VPNServerSelectionResolver.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +enum VPNServerSelectionResolverError: Error { + case countryNotFound + case fetchingLocationsFailed(Error) +} + +protocol VPNServerSelectionResolving { + func resolvedServerSelectionMethod() async -> NetworkProtectionServerSelectionMethod +} + +final class VPNServerSelectionResolver: VPNServerSelectionResolving { + private let locationListRepository: NetworkProtectionLocationListRepository + private let vpnSettings: VPNSettings + + init(locationListRepository: NetworkProtectionLocationListRepository, vpnSettings: VPNSettings) { + self.locationListRepository = locationListRepository + self.vpnSettings = vpnSettings + } + + public func resolvedServerSelectionMethod() async -> NetworkProtectionServerSelectionMethod { + switch currentServerSelectionMethod { + case .automatic, .preferredServer, .avoidServer, .failureRecovery: + return currentServerSelectionMethod + case .preferredLocation(let networkProtectionSelectedLocation): + do { + let location = try await resolveSelectionAgainstAvailableLocations(networkProtectionSelectedLocation) + return .preferredLocation(location) + } catch let error as VPNServerSelectionResolverError { + switch error { + case .countryNotFound: + return .automatic + case .fetchingLocationsFailed: + return currentServerSelectionMethod + } + } catch { + return currentServerSelectionMethod + } + } + } + + private func resolveSelectionAgainstAvailableLocations(_ selection: NetworkProtectionSelectedLocation) async throws -> NetworkProtectionSelectedLocation { + let availableLocations: [NetworkProtectionLocation] + do { + availableLocations = try await locationListRepository.fetchLocationListIgnoringCache() + } catch { + throw VPNServerSelectionResolverError.fetchingLocationsFailed(error) + } + + let availableCitySelections = availableLocations.flatMap { location in + location.cities.map { city in NetworkProtectionSelectedLocation(country: location.country, city: city.name) } + } + + if availableCitySelections.contains(selection) { + return selection + } + + let selectedCountry = NetworkProtectionSelectedLocation(country: selection.country) + let availableCountrySelections = availableLocations.map { NetworkProtectionSelectedLocation(country: $0.country) } + guard availableCountrySelections.contains(selectedCountry) else { + throw VPNServerSelectionResolverError.countryNotFound + } + + return selectedCountry + } + + private var currentServerSelectionMethod: NetworkProtectionServerSelectionMethod { + var serverSelectionMethod: NetworkProtectionServerSelectionMethod + + switch vpnSettings.selectedLocation { + case .nearest: + serverSelectionMethod = .automatic + case .location(let networkProtectionSelectedLocation): + serverSelectionMethod = .preferredLocation(networkProtectionSelectedLocation) + } + + switch vpnSettings.selectedServer { + case .automatic: + break + case .endpoint(let string): + // Selecting a specific server will override locations setting + // Only available in debug + serverSelectionMethod = .preferredServer(serverName: string) + } + + return serverSelectionMethod + } +} diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift index 533cdb2b9..1894f0715 100644 --- a/Sources/NetworkProtection/PacketTunnelProvider.swift +++ b/Sources/NetworkProtection/PacketTunnelProvider.swift @@ -172,6 +172,16 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: - Server Selection + private lazy var serverSelection: VPNServerSelectionResolving = { + let locationRepository = NetworkProtectionLocationListCompositeRepository( + environment: settings.selectedEnvironment, + tokenStore: tokenStore, + errorEvents: debugEvents, + isSubscriptionEnabled: isSubscriptionEnabled + ) + return VPNServerSelectionResolver(locationListRepository: locationRepository, vpnSettings: settings) + }() + @MainActor private var lastSelectedServer: NetworkProtectionServer? { didSet { @@ -343,7 +353,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: - Initializers private let keychainType: KeychainType - private let debugEvents: EventMapping? + private let debugEvents: EventMapping private let providerEvents: EventMapping public let isSubscriptionEnabled: Bool @@ -354,7 +364,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { controllerErrorStore: NetworkProtectionTunnelErrorStore, keychainType: KeychainType, tokenStore: NetworkProtectionTokenStore, - debugEvents: EventMapping?, + debugEvents: EventMapping, providerEvents: EventMapping, settings: VPNSettings, defaults: UserDefaults, @@ -602,36 +612,12 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { completionHandler: completionHandler) } - var currentServerSelectionMethod: NetworkProtectionServerSelectionMethod { - var serverSelectionMethod: NetworkProtectionServerSelectionMethod - - switch settings.selectedLocation { - case .nearest: - serverSelectionMethod = .automatic - case .location(let networkProtectionSelectedLocation): - serverSelectionMethod = .preferredLocation(networkProtectionSelectedLocation) - } - - switch settings.selectedServer { - case .automatic: - break - case .endpoint(let string): - // Selecting a specific server will override locations setting - // Only available in debug - serverSelectionMethod = .preferredServer(serverName: string) - } - - return serverSelectionMethod - } - private func startTunnel(onDemand: Bool, completionHandler: @escaping (Error?) -> Void) { Task { do { os_log("🔵 Generating tunnel config", log: .networkProtection, type: .info) os_log("🔵 Excluded ranges are: %{public}@", log: .networkProtection, type: .info, String(describing: settings.excludedRanges)) - os_log("🔵 Server selection method: %{public}@", log: .networkProtection, type: .info, currentServerSelectionMethod.debugDescription) - let tunnelConfiguration = try await generateTunnelConfiguration(serverSelectionMethod: currentServerSelectionMethod, - includedRoutes: includedRoutes ?? [], + let tunnelConfiguration = try await generateTunnelConfiguration(includedRoutes: includedRoutes ?? [], excludedRoutes: settings.excludedRanges, regenerateKey: true) startTunnel(with: tunnelConfiguration, onDemand: onDemand, completionHandler: completionHandler) @@ -651,7 +637,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { adapter.start(tunnelConfiguration: tunnelConfiguration) { [weak self] error in if let error { os_log("🔵 Starting tunnel failed with %{public}@", log: .networkProtection, type: .error, error.localizedDescription) - self?.debugEvents?.fire(error.networkProtectionError) + self?.debugEvents.fire(error.networkProtectionError) completionHandler(error) return } @@ -745,7 +731,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { if let error { os_log("🔵 Error while stopping adapter: %{public}@", log: .networkProtection, type: .error, error.localizedDescription) - self?.debugEvents?.fire(error.networkProtectionError) + self?.debugEvents.fire(error.networkProtectionError) continuation.resume(throwing: error) return @@ -775,17 +761,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: - Tunnel Configuration @MainActor - public func updateTunnelConfiguration(reassert: Bool, regenerateKey: Bool = false) async throws { - try await updateTunnelConfiguration( - serverSelectionMethod: currentServerSelectionMethod, - reassert: reassert, - regenerateKey: regenerateKey - ) - } - - @MainActor - public func updateTunnelConfiguration(serverSelectionMethod: NetworkProtectionServerSelectionMethod, - reassert: Bool, + public func updateTunnelConfiguration(reassert: Bool, regenerateKey: Bool = false) async throws { providerEvents.fire(.tunnelUpdateAttempt(.begin)) @@ -796,10 +772,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { let tunnelConfiguration: TunnelConfiguration do { - tunnelConfiguration = try await generateTunnelConfiguration(serverSelectionMethod: serverSelectionMethod, - includedRoutes: includedRoutes ?? [], - excludedRoutes: settings.excludedRanges, - regenerateKey: regenerateKey) + tunnelConfiguration = try await generateTunnelConfiguration(includedRoutes: includedRoutes ?? [], + excludedRoutes: settings.excludedRanges, + regenerateKey: regenerateKey) } catch { providerEvents.fire(.tunnelUpdateAttempt(.failure(error))) throw error @@ -819,7 +794,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { self.adapter.update(tunnelConfiguration: tunnelConfiguration, reassert: reassert) { [weak self] error in if let error = error { os_log("🔵 Failed to update the configuration: %{public}@", type: .error, error.localizedDescription) - self?.debugEvents?.fire(error.networkProtectionError) + self?.debugEvents.fire(error.networkProtectionError) continuation.resume(throwing: error) return } @@ -846,14 +821,14 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } @MainActor - private func generateTunnelConfiguration(serverSelectionMethod: NetworkProtectionServerSelectionMethod, - includedRoutes: [IPAddressRange], + private func generateTunnelConfiguration(includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange], regenerateKey: Bool) async throws -> TunnelConfiguration { let configurationResult: NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult do { + let serverSelectionMethod = await serverSelection.resolvedServerSelectionMethod() configurationResult = try await deviceManager.generateTunnelConfiguration( selectionMethod: serverSelectionMethod, includedRoutes: includedRoutes, @@ -957,7 +932,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { settings.apply(change: change) } - // swiftlint:disable:next cyclomatic_complexity private func handleSettingsChange(_ change: VPNSettings.Change, completionHandler: ((Data?) -> Void)? = nil) { switch change { case .setExcludeLocalNetworks: @@ -967,35 +941,10 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } completionHandler?(nil) } - case .setSelectedServer(let selectedServer): - let serverSelectionMethod: NetworkProtectionServerSelectionMethod - - switch selectedServer { - case .automatic: - serverSelectionMethod = .automatic - case .endpoint(let serverName): - serverSelectionMethod = .preferredServer(serverName: serverName) - } - - Task { @MainActor in - if case .connected = connectionStatus { - try? await updateTunnelConfiguration(serverSelectionMethod: serverSelectionMethod, reassert: true) - } - completionHandler?(nil) - } - case .setSelectedLocation(let selectedLocation): - let serverSelectionMethod: NetworkProtectionServerSelectionMethod - - switch selectedLocation { - case .nearest: - serverSelectionMethod = .automatic - case .location(let location): - serverSelectionMethod = .preferredLocation(location) - } - + case .setSelectedServer, .setSelectedLocation: Task { @MainActor in if case .connected = connectionStatus { - try? await updateTunnelConfiguration(serverSelectionMethod: serverSelectionMethod, reassert: true) + try? await updateTunnelConfiguration(reassert: true) } completionHandler?(nil) } @@ -1093,7 +1042,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { settings.selectedServer = .endpoint(serverName) if case .connected = connectionStatus { - try? await updateTunnelConfiguration(serverSelectionMethod: .preferredServer(serverName: serverName), reassert: true) + try? await updateTunnelConfiguration(reassert: true) } completionHandler?(nil) } @@ -1174,7 +1123,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { adapter.stop { [weak self] error in if let error { - self?.debugEvents?.fire(error.networkProtectionError) + self?.debugEvents.fire(error.networkProtectionError) os_log("🔵 Failed to stop WireGuard adapter: %{public}@", log: .networkProtection, type: .info, error.localizedDescription) } diff --git a/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift b/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift index be741b8db..2341de78e 100644 --- a/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift +++ b/Sources/NetworkProtection/Repositories/NetworkProtectionLocationListRepository.swift @@ -21,9 +21,11 @@ import Common public protocol NetworkProtectionLocationListRepository { func fetchLocationList() async throws -> [NetworkProtectionLocation] + func fetchLocationListIgnoringCache() async throws -> [NetworkProtectionLocation] } final public class NetworkProtectionLocationListCompositeRepository: NetworkProtectionLocationListRepository { + @MainActor private static var locationList: [NetworkProtectionLocation] = [] @MainActor private static var cacheTimestamp = Date() private static let cacheValidity = TimeInterval(60) // Refreshes at most once per minute @@ -60,6 +62,11 @@ final public class NetworkProtectionLocationListCompositeRepository: NetworkProt guard !canUseCache else { return Self.locationList } + return try await fetchLocationListIgnoringCache() + } + + @MainActor + public func fetchLocationListIgnoringCache() async throws -> [NetworkProtectionLocation] { do { guard let authToken = try tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound diff --git a/Sources/NetworkProtectionTestUtils/MockNetworkProtectionDeviceManagement.swift b/Sources/NetworkProtectionTestUtils/MockNetworkProtectionDeviceManagement.swift index 8c9e30e17..952124ebb 100644 --- a/Sources/NetworkProtectionTestUtils/MockNetworkProtectionDeviceManagement.swift +++ b/Sources/NetworkProtectionTestUtils/MockNetworkProtectionDeviceManagement.swift @@ -20,6 +20,7 @@ import Foundation import NetworkProtection public final class MockNetworkProtectionDeviceManagement: NetworkProtectionDeviceManagement { + enum MockError: Error { case noStubSet } @@ -42,12 +43,7 @@ public final class MockNetworkProtectionDeviceManagement: NetworkProtectionDevic public init() {} - public func generateTunnelConfiguration( - selectionMethod: NetworkProtection.NetworkProtectionServerSelectionMethod, - includedRoutes: [NetworkProtection.IPAddressRange], - excludedRoutes: [NetworkProtection.IPAddressRange], - isKillSwitchEnabled: Bool, - regenerateKey: Bool) async throws -> (tunnelConfiguration: NetworkProtection.TunnelConfiguration, server: NetworkProtection.NetworkProtectionServer) { + public func generateTunnelConfiguration(selectionMethod: NetworkProtection.NetworkProtectionServerSelectionMethod, includedRoutes: [NetworkProtection.IPAddressRange], excludedRoutes: [NetworkProtection.IPAddressRange], isKillSwitchEnabled: Bool, regenerateKey: Bool) async throws -> NetworkProtectionDeviceManagement.GenerateTunnelConfigurationResult { spyGenerateTunnelConfiguration = ( selectionMethod: selectionMethod, includedRoutes: includedRoutes, diff --git a/Sources/NetworkProtectionTestUtils/Repositories/MockNetworkProtectionLocationListRepository.swift b/Sources/NetworkProtectionTestUtils/Repositories/MockNetworkProtectionLocationListRepository.swift new file mode 100644 index 000000000..793e46a42 --- /dev/null +++ b/Sources/NetworkProtectionTestUtils/Repositories/MockNetworkProtectionLocationListRepository.swift @@ -0,0 +1,50 @@ +// +// MockNetworkProtectionLocationListRepository.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import NetworkProtection + +final class MockNetworkProtectionLocationListRepository: NetworkProtectionLocationListRepository { + var stubFetchLocationList: [NetworkProtectionLocation] = [] + var stubFetchLocationListError: Error? + var spyIgnoreCache: Bool = false + + func fetchLocationList() async throws -> [NetworkProtectionLocation] { + if let stubFetchLocationListError { + throw stubFetchLocationListError + } + return stubFetchLocationList + } + + func fetchLocationListIgnoringCache() async throws -> [NetworkProtection.NetworkProtectionLocation] { + spyIgnoreCache = true + return try await fetchLocationList() + } +} + +extension NetworkProtectionLocation { + static func testData(country: String = "", cities: [City] = []) -> NetworkProtectionLocation { + return Self(country: country, cities: cities) + } +} + +extension NetworkProtectionLocation.City { + static func testData(name: String = "") -> NetworkProtectionLocation.City { + Self(name: name) + } +} diff --git a/Tests/NetworkProtectionTests/Models/VPNServerSelectionResolverTests.swift b/Tests/NetworkProtectionTests/Models/VPNServerSelectionResolverTests.swift new file mode 100644 index 000000000..26cc73a8d --- /dev/null +++ b/Tests/NetworkProtectionTests/Models/VPNServerSelectionResolverTests.swift @@ -0,0 +1,148 @@ +// +// VPNServerSelectionResolverTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import XCTest +@testable import NetworkProtection +@testable import NetworkProtectionTestUtils + +final class VPNServerSelectionResolverTests: XCTestCase { + var resolver: VPNServerSelectionResolver! + var vpnSettings: VPNSettings! + var locationListRepository: MockNetworkProtectionLocationListRepository! + + override func setUp() { + super.setUp() + locationListRepository = MockNetworkProtectionLocationListRepository() + vpnSettings = VPNSettings(defaults: UserDefaults(suiteName: self.className)!) + resolver = VPNServerSelectionResolver( + locationListRepository: locationListRepository, + vpnSettings: vpnSettings + ) + } + + override func tearDown() { + vpnSettings.resetToDefaults() + vpnSettings = nil + resolver = nil + locationListRepository = nil + super.tearDown() + } + + func testResolvedServerSelectionMethod_selectedServer_returnsPreferredServer() async { + let serverName = "serverName" + vpnSettings.selectedServer = .endpoint(serverName) + let result = await resolver.resolvedServerSelectionMethod() + guard case .preferredServer(let preferredServerName) = result else { + XCTFail("Expected preferredServer method") + return + } + XCTAssertEqual(preferredServerName, serverName) + } + + func testResolvedServerSelectionMethod_selectedLocationIsNearest_returnsAutomatic() async { + vpnSettings.selectedLocation = .nearest + let result = await resolver.resolvedServerSelectionMethod() + guard case .automatic = result else { + XCTFail("Expected automatic method") + return + } + } + + func testResolvedServerSelectionMethod_selectedLocationIsCity_fetchesListIgnoringCache() async { + vpnSettings.selectedLocation = .location(.init(country: "nl", city: "Rotterdam")) + _ = await resolver.resolvedServerSelectionMethod() + XCTAssertTrue(locationListRepository.spyIgnoreCache) + } + + func testResolvedServerSelectionMethod_selectedLocationIsCity_fetchedLocationsContainThatCity_returnsPreferredCity() async { + let selectedLocation = NetworkProtectionSelectedLocation(country: "nl", city: "Rotterdam") + vpnSettings.selectedLocation = .location(selectedLocation) + locationListRepository.stubFetchLocationList = [ + .testData(country: "us"), + .testData( + country: "nl", + cities: [.testData(name: "Rotterdam")] + ) + ] + let result = await resolver.resolvedServerSelectionMethod() + guard case .preferredLocation(let location) = result else { + XCTFail("Expected preferredLocation method") + return + } + XCTAssertEqual(location, selectedLocation) + } + + func testResolvedServerSelectionMethod_selectedLocationIsCity_fetchedLocationsContainThatCountry_butNotCity_returnsPreferredCountryWithNilCity() async { + let selectedLocation = NetworkProtectionSelectedLocation(country: "nl", city: nil) + vpnSettings.selectedLocation = .location(selectedLocation) + locationListRepository.stubFetchLocationList = [ + .testData(country: "us"), + .testData( + country: "nl", + cities: [.testData(name: "Amsterdam")] + ) + ] + let result = await resolver.resolvedServerSelectionMethod() + guard case .preferredLocation(let location) = result else { + XCTFail("Expected preferredLocation method") + return + } + XCTAssertEqual(location, selectedLocation) + } + + func testResolvedServerSelectionMethod_selectedLocationIsCity_fetchedLocationsDoesNotContainCountry_returnsAutomatic() async { + let selectedLocation = NetworkProtectionSelectedLocation(country: "nl", city: nil) + vpnSettings.selectedLocation = .location(selectedLocation) + locationListRepository.stubFetchLocationList = [ + .testData(country: "us") + ] + let result = await resolver.resolvedServerSelectionMethod() + guard case .automatic = result else { + XCTFail("Expected automatic method") + return + } + } + + func testResolvedServerSelectionMethod_overridesAllLocationSelectionMethods_returnsPreferredServer() async { + let cases: [VPNSettings.SelectedLocation] = [ + .location(.init(country: "nl", city: "Rotterdam")), + .location(.init(country: "us", city: nil)), + .nearest + ] + + for currentSelectedLocation in cases { + vpnSettings.selectedLocation = currentSelectedLocation + let selectedServerName = "selectedServer" + vpnSettings.selectedServer = .endpoint(selectedServerName) + locationListRepository.stubFetchLocationList = [ + .testData(country: "us"), + .testData( + country: "nl", + cities: [.testData(name: "Rotterdam")] + ) + ] + let result = await resolver.resolvedServerSelectionMethod() + guard case .preferredServer(let server) = result else { + XCTFail("Expected preferredServer method") + return + } + XCTAssertEqual(server, selectedServerName) + } + } +} diff --git a/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift b/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift index 964aa273e..8e6daf56b 100644 --- a/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift +++ b/Tests/NetworkProtectionTests/Repositories/NetworkProtectionLocationListCompositeRepositoryTests.swift @@ -140,15 +140,3 @@ class NetworkProtectionLocationListCompositeRepositoryTests: XCTestCase { XCTAssertTrue(didReceiveError) } } - -private extension NetworkProtectionLocation { - static func testData(country: String = "", cities: [City] = []) -> NetworkProtectionLocation { - return Self(country: country, cities: cities) - } -} - -private extension NetworkProtectionLocation.City { - static func testData(name: String = "") -> NetworkProtectionLocation.City { - Self(name: name) - } -}