Skip to content

Commit

Permalink
Implement server selection resolver
Browse files Browse the repository at this point in the history
Co-Authored-By: Graeme Arthur <[email protected]>
  • Loading branch information
quanganhdo and graeme committed Aug 21, 2024
1 parent 6f36c07 commit cb3efb0
Show file tree
Hide file tree
Showing 9 changed files with 366 additions and 8 deletions.
109 changes: 109 additions & 0 deletions Sources/NetworkProtection/Models/VPNServerSelectionResolver.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//
// VPNServerSelectionResolver.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

enum VPNServerSelectionResolverError: Error {
case countryNotFound
case fetchingLocationsFailed(Error)
}

protocol VPNServerSelectionResolving {
func resolvedServerSelectionMethod() async -> NetworkProtectionServerSelectionMethod
}

final class VPNServerSelectionResolver: VPNServerSelectionResolving {
private let locationListRepository: NetworkProtectionLocationListRepository
private let vpnSettings: VPNSettings

init(locationListRepository: NetworkProtectionLocationListRepository, vpnSettings: VPNSettings) {
self.locationListRepository = locationListRepository
self.vpnSettings = vpnSettings
}

/// Address the case where the prefered location becomes unavailable
/// We fall back to the country, if a city isn't available,
/// or nearest if the country isn't available
public func resolvedServerSelectionMethod() async -> NetworkProtectionServerSelectionMethod {
switch currentServerSelectionMethod {
case .automatic, .preferredServer, .avoidServer, .failureRecovery:
return currentServerSelectionMethod
case .preferredLocation(let networkProtectionSelectedLocation):
do {
let location = try await resolveSelectionAgainstAvailableLocations(networkProtectionSelectedLocation)
return .preferredLocation(location)
} catch let error as VPNServerSelectionResolverError {
switch error {
case .countryNotFound:
return .automatic
case .fetchingLocationsFailed:
return currentServerSelectionMethod
}
} catch {
return currentServerSelectionMethod
}
}
}

private func resolveSelectionAgainstAvailableLocations(_ selection: NetworkProtectionSelectedLocation) async throws -> NetworkProtectionSelectedLocation {
let availableLocations: [NetworkProtectionLocation]
do {
availableLocations = try await locationListRepository.fetchLocationList(cachePolicy: .ignoreCache)
} catch {
throw VPNServerSelectionResolverError.fetchingLocationsFailed(error)
}

let availableCitySelections = availableLocations.flatMap { location in
location.cities.map { city in NetworkProtectionSelectedLocation(country: location.country, city: city.name) }
}

if availableCitySelections.contains(selection) {
return selection
}

let selectedCountry = NetworkProtectionSelectedLocation(country: selection.country)
let availableCountrySelections = availableLocations.map { NetworkProtectionSelectedLocation(country: $0.country) }
guard availableCountrySelections.contains(selectedCountry) else {
throw VPNServerSelectionResolverError.countryNotFound
}

return selectedCountry
}

private var currentServerSelectionMethod: NetworkProtectionServerSelectionMethod {
var serverSelectionMethod: NetworkProtectionServerSelectionMethod

switch vpnSettings.selectedLocation {
case .nearest:
serverSelectionMethod = .automatic
case .location(let networkProtectionSelectedLocation):
serverSelectionMethod = .preferredLocation(networkProtectionSelectedLocation)
}

switch vpnSettings.selectedServer {
case .automatic:
break
case .endpoint(let string):
// Selecting a specific server will override locations setting
// Only available in debug
serverSelectionMethod = .preferredServer(serverName: string)
}

return serverSelectionMethod
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public enum NetworkProtectionDNSSettings: Codable, Equatable, CustomStringConver
public protocol NetworkProtectionDeviceManagement {
typealias GenerateTunnelConfigurationResult = (tunnelConfiguration: TunnelConfiguration, server: NetworkProtectionServer)

func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
func generateTunnelConfiguration(resolvedSelectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
dnsSettings: NetworkProtectionDNSSettings,
Expand Down Expand Up @@ -133,7 +133,7 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
/// 2. If the key is new, register it with all backend servers and return a tunnel configuration + its server info
/// 3. If the key already existed, look up the stored set of backend servers and check if the preferred server is registered. If not, register it, and return the tunnel configuration + server info.
///
public func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
public func generateTunnelConfiguration(resolvedSelectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
dnsSettings: NetworkProtectionDNSSettings,
Expand All @@ -156,7 +156,7 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
}
}

let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: selectionMethod)
let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: resolvedSelectionMethod)
os_log("Server registration successul", log: .networkProtection)

keyStore.updateKeyPair(keyPair)
Expand Down
13 changes: 12 additions & 1 deletion Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

// MARK: - Server Selection

private lazy var serverSelectionResolver: VPNServerSelectionResolving = {
let locationRepository = NetworkProtectionLocationListCompositeRepository(
environment: settings.selectedEnvironment,
tokenStore: tokenStore,
errorEvents: debugEvents,
isSubscriptionEnabled: isSubscriptionEnabled
)
return VPNServerSelectionResolver(locationListRepository: locationRepository, vpnSettings: settings)
}()

@MainActor
private var lastSelectedServer: NetworkProtectionServer? {
didSet {
Expand Down Expand Up @@ -980,10 +990,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
regenerateKey: Bool) async throws -> TunnelConfiguration {

let configurationResult: NetworkProtectionDeviceManager.GenerateTunnelConfigurationResult
let resolvedServerSelectionMethod = await serverSelectionResolver.resolvedServerSelectionMethod()

do {
configurationResult = try await deviceManager.generateTunnelConfiguration(
selectionMethod: serverSelectionMethod,
resolvedSelectionMethod: resolvedServerSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
dnsSettings: dnsSettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ actor FailureRecoveryHandler: FailureRecoveryHandling {
let configurationResult: NetworkProtectionDeviceManagement.GenerateTunnelConfigurationResult

configurationResult = try await deviceManager.generateTunnelConfiguration(
selectionMethod: serverSelectionMethod,
resolvedSelectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
dnsSettings: dnsSettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@
import Foundation
import Common

public enum NetworkProtectionLocationListCachePolicy {
case returnCacheElseLoad
case ignoreCache

static var `default` = NetworkProtectionLocationListCachePolicy.returnCacheElseLoad
}

public protocol NetworkProtectionLocationListRepository {
func fetchLocationList() async throws -> [NetworkProtectionLocation]
func fetchLocationList(cachePolicy: NetworkProtectionLocationListCachePolicy) async throws -> [NetworkProtectionLocation]
}

final public class NetworkProtectionLocationListCompositeRepository: NetworkProtectionLocationListRepository {
Expand Down Expand Up @@ -54,12 +62,35 @@ final public class NetworkProtectionLocationListCompositeRepository: NetworkProt
self.isSubscriptionEnabled = isSubscriptionEnabled
}

@MainActor
@discardableResult
public func fetchLocationList(cachePolicy: NetworkProtectionLocationListCachePolicy) async throws -> [NetworkProtectionLocation] {
switch cachePolicy {
case .returnCacheElseLoad:
return try await fetchLocationList()
case .ignoreCache:
return try await fetchLocationListFromRemote()
}
}

@MainActor
@discardableResult
public func fetchLocationList() async throws -> [NetworkProtectionLocation] {
try await fetchLocationListReturningCacheElseLoad()
}

@MainActor
@discardableResult
func fetchLocationListReturningCacheElseLoad() async throws -> [NetworkProtectionLocation] {
guard !canUseCache else {
return Self.locationList
}
return try await fetchLocationListFromRemote()
}

@MainActor
@discardableResult
func fetchLocationListFromRemote() async throws -> [NetworkProtectionLocation] {
do {
guard let authToken = try tokenStore.fetchToken() else {
throw NetworkProtectionError.noAuthTokenFound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ public final class MockNetworkProtectionDeviceManagement: NetworkProtectionDevic
public init() {}

public func generateTunnelConfiguration(
selectionMethod: NetworkProtection.NetworkProtectionServerSelectionMethod,
resolvedSelectionMethod: NetworkProtection.NetworkProtectionServerSelectionMethod,
includedRoutes: [NetworkProtection.IPAddressRange],
excludedRoutes: [NetworkProtection.IPAddressRange],
dnsSettings: NetworkProtectionDNSSettings,
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (tunnelConfiguration: NetworkProtection.TunnelConfiguration, server: NetworkProtection.NetworkProtectionServer) {
spyGenerateTunnelConfiguration = (
selectionMethod: selectionMethod,
selectionMethod: resolvedSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
isKillSwitchEnabled: isKillSwitchEnabled,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//
// MockNetworkProtectionLocationListRepository.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation
@testable import NetworkProtection

final class MockNetworkProtectionLocationListRepository: NetworkProtectionLocationListRepository {
var stubFetchLocationList: [NetworkProtectionLocation] = []
var stubFetchLocationListError: Error?
var spyIgnoreCache: Bool = false

func fetchLocationList() async throws -> [NetworkProtectionLocation] {
if let stubFetchLocationListError {
throw stubFetchLocationListError
}
return stubFetchLocationList
}

func fetchLocationList(cachePolicy: NetworkProtectionLocationListCachePolicy) async throws -> [NetworkProtectionLocation] {
switch cachePolicy {
case .returnCacheElseLoad:
return try await fetchLocationList()
case .ignoreCache:
return try await fetchLocationListIgnoringCache()
}
}

func fetchLocationListIgnoringCache() async throws -> [NetworkProtection.NetworkProtectionLocation] {
spyIgnoreCache = true
return try await fetchLocationList()
}
}

extension NetworkProtectionLocation {
static func testData(country: String = "", cities: [City] = []) -> NetworkProtectionLocation {
return Self(country: country, cities: cities)
}
}

extension NetworkProtectionLocation.City {
static func testData(name: String = "") -> NetworkProtectionLocation.City {
Self(name: name)
}
}
Loading

0 comments on commit cb3efb0

Please sign in to comment.