diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme index 56a2ef845..7aeef5267 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme @@ -491,9 +491,9 @@ buildForAnalyzing = "YES"> @@ -539,6 +539,20 @@ ReferencedContainer = "container:"> + + + + @@ -818,6 +832,16 @@ ReferencedContainer = "container:"> + + + + Bool { Logger.contentBlocking.debug("Fetch last compiled rules: \(lastCompiledRules.count, privacy: .public)") let initialCompilationTask = LastCompiledRulesLookupTask(sourceRules: rulesSource.contentBlockerRulesLists, @@ -274,8 +281,10 @@ public class ContentBlockerRulesManager: CompiledRuleListsSource { // We want to confine Compilation work to WorkQueue, so we wait to come back from async Task mutex.wait() - if let rules = initialCompilationTask.getFetchedRules() { - applyRules(rules) + let rulesFound = initialCompilationTask.getFetchedRules() + + if let rulesFound { + applyRules(rulesFound) } else { lock.lock() state = .idle @@ -284,6 +293,8 @@ public class ContentBlockerRulesManager: CompiledRuleListsSource { // No matter if rules were found or not, we need to schedule recompilation, after all scheduleCompilation() + + return rulesFound != nil } private func prepareSourceManagers() { diff --git a/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift b/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift index ba10bf76f..0c83b7dd8 100644 --- a/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift +++ b/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift @@ -16,11 +16,11 @@ // limitations under the License. // -import WebKit +import Common +import ContentBlocking import TrackerRadarKit import UserScript -import ContentBlocking -import Common +@preconcurrency import WebKit public protocol SurrogatesUserScriptDelegate: NSObjectProtocol { diff --git a/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift b/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift index 215dbcc6f..f29a6e520 100644 --- a/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift +++ b/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift @@ -43,8 +43,7 @@ public final class SpecialPagesUserScript: NSObject, UserScript, UserScriptMessa @available(macOS 11.0, iOS 14.0, *) extension SpecialPagesUserScript: WKScriptMessageHandlerWithReply { @MainActor - public func userContentController(_ userContentController: WKUserContentController, - didReceive message: WKScriptMessage) async -> (Any?, String?) { + public func userContentController(_ userContentController: WKUserContentController, didReceive message: WKScriptMessage) async -> (Any?, String?) { let action = broker.messageHandlerFor(message) do { let json = try await broker.execute(action: action, original: message) diff --git a/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift b/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift index 9d534abe2..8c2ee2169 100644 --- a/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift +++ b/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift @@ -18,10 +18,10 @@ import Combine import Common -import UserScript -import WebKit -import QuartzCore import os.log +import QuartzCore +import UserScript +@preconcurrency import WebKit public protocol UserContentControllerDelegate: AnyObject { @MainActor diff --git a/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift new file mode 100644 index 000000000..60a03f8c4 --- /dev/null +++ b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift @@ -0,0 +1,144 @@ +// +// ExperimentCohortsManager.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public typealias CohortID = String +public typealias SubfeatureID = String +public typealias ParentFeatureID = String +public typealias Experiments = [String: ExperimentData] + +public struct ExperimentSubfeature { + let parentID: ParentFeatureID + let subfeatureID: SubfeatureID + let cohorts: [PrivacyConfigurationData.Cohort] +} + +public struct ExperimentData: Codable, Equatable { + public let parentID: ParentFeatureID + public let cohortID: CohortID + public let enrollmentDate: Date +} + +public protocol ExperimentCohortsManaging { + /// Retrieves all the experiments a user is enrolled in + var experiments: Experiments? { get } + + /// Resolves the cohort for a given experiment subfeature. + /// + /// This method determines whether the user is currently assigned to a valid cohort + /// for the specified experiment. If the assigned cohort is valid (i.e., it matches + /// one of the experiment's defined cohorts), the method returns the assigned cohort. + /// Otherwise, the invalid cohort is removed, and a new cohort is assigned if + /// `allowCohortReassignment` is `true`. + /// + /// - Parameters: + /// - experiment: The `ExperimentSubfeature` representing the experiment and its associated cohorts. + /// - allowCohortReassignment: A Boolean value indicating whether cohort assignment is allowed + /// if the user is not already assigned to a valid cohort. + /// + /// - Returns: The valid `CohortID` assigned to the user for the experiment, or `nil` + /// if no valid cohort exists and `allowCohortReassignment` is `false`. + /// + /// - Behavior: + /// 1. Retrieves the currently assigned cohort for the experiment using the `subfeatureID`. + /// 2. Validates if the assigned cohort exists within the experiment's cohort list: + /// - If valid, the assigned cohort is returned. + /// - If invalid, the cohort is removed from storage. + /// 3. If cohort assignment is enabled (`allowCohortReassignment` is `true`), a new cohort + /// is assigned based on the experiment's cohort weights and saved in storage. + /// - Cohort assignment is probabilistic, determined by the cohort weights. + /// + func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? +} + +public class ExperimentCohortsManager: ExperimentCohortsManaging { + + private var store: ExperimentsDataStoring + private let randomizer: (Range) -> Double + private let queue = DispatchQueue(label: "com.ExperimentCohortsManager.queue") + + public var experiments: Experiments? { + get { + queue.sync { + store.experiments + } + } + } + + public init(store: ExperimentsDataStoring = ExperimentsDataStore(), randomizer: @escaping (Range) -> Double = Double.random(in:)) { + self.store = store + self.randomizer = randomizer + } + + public func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? { + queue.sync { + let assignedCohort = cohort(for: experiment.subfeatureID) + if experiment.cohorts.contains(where: { $0.name == assignedCohort }) { + return assignedCohort + } + removeCohort(from: experiment.subfeatureID) + return allowCohortReassignment ? assignCohort(to: experiment) : nil + } + } +} + +// MARK: Helper functions +extension ExperimentCohortsManager { + + private func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID? { + let cohorts = subfeature.cohorts + let totalWeight = cohorts.map(\.weight).reduce(0, +) + guard totalWeight > 0 else { return nil } + + let randomValue = randomizer(0.. CohortID? { + guard let experiments = store.experiments else { return nil } + return experiments[subfeatureID]?.cohortID + } + + private func enrollmentDate(for subfeatureID: SubfeatureID) -> Date? { + guard let experiments = store.experiments else { return nil } + return experiments[subfeatureID]?.enrollmentDate + } + + private func removeCohort(from subfeatureID: SubfeatureID) { + guard var experiments = store.experiments else { return } + experiments.removeValue(forKey: subfeatureID) + store.experiments = experiments + } + + private func saveCohort(_ cohort: CohortID, in experimentID: SubfeatureID, parentID: ParentFeatureID) { + var experiments = store.experiments ?? Experiments() + let experimentData = ExperimentData(parentID: parentID, cohortID: cohort, enrollmentDate: Date()) + experiments[experimentID] = experimentData + store.experiments = experiments + } +} diff --git a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift similarity index 86% rename from Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift rename to Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift index bdf82819a..c99184df6 100644 --- a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift +++ b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift @@ -18,16 +18,16 @@ import Foundation -protocol ExperimentsDataStoring { +public protocol ExperimentsDataStoring { var experiments: Experiments? { get set } } -protocol LocalDataStoring { +public protocol LocalDataStoring { func data(forKey defaultName: String) -> Data? func set(_ value: Any?, forKey defaultName: String) } -struct ExperimentsDataStore: ExperimentsDataStoring { +public struct ExperimentsDataStore: ExperimentsDataStoring { private enum Constants { static let experimentsDataKey = "ExperimentsData" @@ -36,13 +36,13 @@ struct ExperimentsDataStore: ExperimentsDataStoring { private let decoder = JSONDecoder() private let encoder = JSONEncoder() - init(localDataStoring: LocalDataStoring = UserDefaults.standard) { + public init(localDataStoring: LocalDataStoring = UserDefaults.standard) { self.localDataStoring = localDataStoring encoder.dateEncodingStrategy = .secondsSince1970 decoder.dateDecodingStrategy = .secondsSince1970 } - var experiments: Experiments? { + public var experiments: Experiments? { get { guard let savedData = localDataStoring.data(forKey: Constants.experimentsDataKey) else { return nil } return try? decoder.decode(Experiments.self, from: savedData) diff --git a/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift b/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift index d6f5fdee4..3ecb53d1e 100644 --- a/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift +++ b/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift @@ -18,6 +18,8 @@ import Foundation +public protocol FlagCohort: RawRepresentable, CaseIterable where RawValue == CohortID {} + /// This protocol defines a common interface for feature flags managed by FeatureFlagger. /// /// It should be implemented by the feature flag type in client apps. @@ -53,7 +55,43 @@ public protocol FeatureFlagDescribing: CaseIterable { /// case .sync: /// return .disabled /// case .cookieConsent: - /// return .internalOnly + /// return .internalOnly() + /// case .credentialsAutofill: + /// return .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)) + /// case .duckPlayer: + /// return .remoteReleasable(.feature(.duckPlayer)) + /// } + /// } + /// ``` + var source: FeatureFlagSource { get } +} + +/// This protocol defines a common interface for experiment feature flags managed by FeatureFlagger. +/// +/// It should be implemented by the feature flag type in client apps. +/// +public protocol FeatureFlagExperimentDescribing { + + /// Returns a string representation of the flag + var rawValue: String { get } + + /// Defines the source of the experiment feature flag, which corresponds to + /// where the final flag value should come from. + /// + /// Example client implementation: + /// + /// ``` + /// public enum FeatureFlag: FeatureFlagDescribing { + /// case sync + /// case autofill + /// case cookieConsent + /// case duckPlayer + /// + /// var source: FeatureFlagSource { + /// case .sync: + /// return .disabled + /// case .cookieConsent: + /// return .internalOnly(cohort) /// case .credentialsAutofill: /// return .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)) /// case .duckPlayer: @@ -62,6 +100,29 @@ public protocol FeatureFlagDescribing: CaseIterable { /// } /// ``` var source: FeatureFlagSource { get } + + /// Represents the possible groups or variants within an experiment. + /// + /// The `Cohort` type is used to define user groups or test variations for feature + /// experimentation. Each cohort typically corresponds to a specific behavior or configuration + /// applied to a subset of users. For example, in an A/B test, you might define cohorts such as + /// `control` and `treatment`. + /// + /// Each cohort must conform to the `CohortEnum` protocol, which ensures that the cohort type + /// is an `enum` with `String` raw values and provides access to all possible cases + /// through `CaseIterable`. + /// + /// Example: + /// ``` + /// public enum AutofillCohorts: String, CohortEnum { + /// case control + /// case treatment + /// } + /// ``` + /// + /// The `Cohort` type allows dynamic resolution of cohorts by their raw `String` value, + /// making it easy to map user configurations to specific cohort groups. + associatedtype CohortType: FlagCohort } public enum FeatureFlagSource { @@ -69,7 +130,7 @@ public enum FeatureFlagSource { case disabled /// Enabled for internal users only. Cannot be toggled remotely - case internalOnly + case internalOnly((any FlagCohort)? = nil) /// Toggled remotely using PrivacyConfiguration but only for internal users. Otherwise, disabled. case remoteDevelopment(PrivacyConfigFeatureLevel) @@ -107,6 +168,44 @@ public protocol FeatureFlagger: AnyObject { /// when the non-overridden feature flag value is required. /// func isFeatureOn(for featureFlag: Flag, allowOverride: Bool) -> Bool + + /// Retrieves the cohort for a feature flag if the feature is enabled. + /// + /// This method determines the source of the feature flag and evaluates its eligibility based on + /// the user's internal status and the privacy configuration. It supports different sources, such as + /// disabled features, internal-only features, and remotely toggled features. + /// + /// - Parameter featureFlag: A feature flag conforming to `FeatureFlagDescribing`. + /// + /// - Returns: The `CohortID` associated with the feature flag, or `nil` if the feature is disabled or + /// does not meet the eligibility criteria. + /// + /// - Behavior: + /// - For `.disabled`: Returns `nil`. + /// - For `.internalOnly`: Returns the cohort if the user is an internal user. + /// - For `.remoteDevelopment` and `.remoteReleasable`: + /// - If the feature is a subfeature, resolves its cohort using `getCohortIfEnabled(_ subfeature:)`. + /// - Returns `nil` if the user is not eligible. + /// + func getCohortIfEnabled(for featureFlag: Flag) -> (any FlagCohort)? + + /// Retrieves all active experiments currently assigned to the user. + /// + /// This method iterates over the experiments stored in the `ExperimentManager` and checks their state + /// against the current `PrivacyConfiguration`. If an experiment's state is enabled or disabled due to + /// a target mismatch, and its assigned cohort matches the resolved cohort, it is considered active. + /// + /// - Returns: A dictionary of active experiments where the key is the experiment's subfeature ID, + /// and the value is the associated `ExperimentData`. + /// + /// - Behavior: + /// 1. Fetches all enrolled experiments from the `ExperimentManager`. + /// 2. For each experiment: + /// - Retrieves its state from the `PrivacyConfiguration`. + /// - Validates its assigned cohort using `resolveCohort` in the `ExperimentManager`. + /// 3. If the experiment passes validation, it is added to the result dictionary. + /// + func getAllActiveExperiments() -> Experiments } public extension FeatureFlagger { @@ -126,14 +225,17 @@ public class DefaultFeatureFlagger: FeatureFlagger { public let internalUserDecider: InternalUserDecider public let privacyConfigManager: PrivacyConfigurationManaging + private let experimentManager: ExperimentCohortsManaging? public let localOverrides: FeatureFlagLocalOverriding? public init( internalUserDecider: InternalUserDecider, - privacyConfigManager: PrivacyConfigurationManaging + privacyConfigManager: PrivacyConfigurationManaging, + experimentManager: ExperimentCohortsManaging? ) { self.internalUserDecider = internalUserDecider self.privacyConfigManager = privacyConfigManager + self.experimentManager = experimentManager self.localOverrides = nil } @@ -141,11 +243,13 @@ public class DefaultFeatureFlagger: FeatureFlagger { internalUserDecider: InternalUserDecider, privacyConfigManager: PrivacyConfigurationManaging, localOverrides: FeatureFlagLocalOverriding, + experimentManager: ExperimentCohortsManaging?, for: Flag.Type ) { self.internalUserDecider = internalUserDecider self.privacyConfigManager = privacyConfigManager self.localOverrides = localOverrides + self.experimentManager = experimentManager localOverrides.featureFlagger = self // Clear all overrides if not an internal user @@ -173,6 +277,58 @@ public class DefaultFeatureFlagger: FeatureFlagger { } } + public func getAllActiveExperiments() -> Experiments { + guard let enrolledExperiments = experimentManager?.experiments else { return [:] } + var activeExperiments = [String: ExperimentData]() + let config = privacyConfigManager.privacyConfig + + for (subfeatureID, experimentData) in enrolledExperiments { + let state = config.stateFor(subfeatureID: subfeatureID, parentFeatureID: experimentData.parentID) + guard state == .enabled || state == .disabled(.targetDoesNotMatch) else { continue } + let cohorts = config.cohorts(subfeatureID: subfeatureID, parentFeatureID: experimentData.parentID) ?? [] + let experimentSubfeature = ExperimentSubfeature(parentID: experimentData.parentID, subfeatureID: subfeatureID, cohorts: cohorts) + + if experimentManager?.resolveCohort(for: experimentSubfeature, allowCohortReassignment: false) == experimentData.cohortID { + activeExperiments[subfeatureID] = experimentData + } + } + return activeExperiments + } + + public func getCohortIfEnabled(for featureFlag: Flag) -> (any FlagCohort)? { + switch featureFlag.source { + case .disabled: + return nil + case .internalOnly(let cohort): + return cohort + case .remoteReleasable(let featureType), + .remoteDevelopment(let featureType) where internalUserDecider.isInternalUser: + if case .subfeature(let subfeature) = featureType { + if let resolvedCohortID = getCohortIfEnabled(subfeature) { + return Flag.CohortType.allCases.first { return $0.rawValue == resolvedCohortID } + } + } + return nil + default: + return nil + } + } + + private func getCohortIfEnabled(_ subfeature: any PrivacySubfeature) -> CohortID? { + let config = privacyConfigManager.privacyConfig + let featureState = config.stateFor(subfeature) + let cohorts = config.cohorts(for: subfeature) + let experiment = ExperimentSubfeature(parentID: subfeature.parent.rawValue, subfeatureID: subfeature.rawValue, cohorts: cohorts ?? []) + switch featureState { + case .enabled: + return experimentManager?.resolveCohort(for: experiment, allowCohortReassignment: true) + case .disabled(.targetDoesNotMatch): + return experimentManager?.resolveCohort(for: experiment, allowCohortReassignment: false) + default: + return nil + } + } + private func isEnabled(_ featureType: PrivacyConfigFeatureLevel) -> Bool { switch featureType { case .feature(let feature): diff --git a/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift b/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift index 89e64093d..2390aff4d 100644 --- a/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift +++ b/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift @@ -33,19 +33,23 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { private let locallyUnprotected: DomainsProtectionStore private let internalUserDecider: InternalUserDecider private let userDefaults: UserDefaults + private let locale: Locale private let installDate: Date? + static let experimentManagerQueue = DispatchQueue(label: "com.experimentManager.queue") public init(data: PrivacyConfigurationData, identifier: String, localProtection: DomainsProtectionStore, internalUserDecider: InternalUserDecider, userDefaults: UserDefaults = UserDefaults(), + locale: Locale = Locale.current, installDate: Date? = nil) { self.data = data self.identifier = identifier self.locallyUnprotected = localProtection self.internalUserDecider = internalUserDecider self.userDefaults = userDefaults + self.locale = locale self.installDate = installDate } @@ -137,13 +141,14 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { } } - private func isRolloutEnabled(subfeature: any PrivacySubfeature, + private func isRolloutEnabled(subfeatureID: SubfeatureID, + parentID: ParentFeatureID, rolloutSteps: [PrivacyConfigurationData.PrivacyFeature.Feature.RolloutStep], randomizer: (Range) -> Double) -> Bool { // Empty rollouts should be default enabled guard !rolloutSteps.isEmpty else { return true } - let defsPrefix = "config.\(subfeature.parent.rawValue).\(subfeature.rawValue)" + let defsPrefix = "config.\(parentID).\(subfeatureID)" if userDefaults.bool(forKey: "\(defsPrefix).\(Constants.enabledKey)") { return true } @@ -182,7 +187,6 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { public func isSubfeatureEnabled(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> Bool { - switch stateFor(subfeature, versionProvider: versionProvider, randomizer: randomizer) { case .enabled: return true @@ -191,17 +195,30 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { } } - public func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + public func stateFor(_ subfeature: any PrivacySubfeature, + versionProvider: AppVersionProvider, + randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + guard let subfeatureData = subfeatures(for: subfeature.parent)[subfeature.rawValue] else { + return .disabled(.featureMissing) + } - let parentState = stateFor(featureKey: subfeature.parent, versionProvider: versionProvider) - guard case .enabled = parentState else { return parentState } + return stateFor(subfeatureID: subfeature.rawValue, subfeatureData: subfeatureData, parentFeature: subfeature.parent, versionProvider: versionProvider, randomizer: randomizer) + } - let subfeatures = subfeatures(for: subfeature.parent) - let subfeatureData = subfeatures[subfeature.rawValue] + private func stateFor(subfeatureID: SubfeatureID, + subfeatureData: PrivacyConfigurationData.PrivacyFeature.Feature, + parentFeature: PrivacyFeature, + versionProvider: AppVersionProvider, + randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + // Step 1: Check parent feature state + let parentState = stateFor(featureKey: parentFeature, versionProvider: versionProvider) + guard case .enabled = parentState else { return parentState } - let satisfiesMinVersion = satisfiesMinVersion(subfeatureData?.minSupportedVersion, versionProvider: versionProvider) + // Step 2: Check version + let satisfiesMinVersion = satisfiesMinVersion(subfeatureData.minSupportedVersion, versionProvider: versionProvider) - switch subfeatureData?.state { + // Step 3: Check sub-feature state + switch subfeatureData.state { case PrivacyConfigurationData.State.enabled: guard satisfiesMinVersion else { return .disabled(.appVersionNotSupported) } case PrivacyConfigurationData.State.internal: @@ -210,15 +227,31 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { default: return .disabled(.disabledInConfig) } - // Handle Rollouts - if let rollout = subfeatureData?.rollout, - !isRolloutEnabled(subfeature: subfeature, rolloutSteps: rollout.steps, randomizer: randomizer) { + // Step 4: Handle Rollouts + if let rollout = subfeatureData.rollout, + !isRolloutEnabled(subfeatureID: subfeatureID, parentID: parentFeature.rawValue, rolloutSteps: rollout.steps, randomizer: randomizer) { return .disabled(.stillInRollout) } + // Step 5: Check Targets + return checkTargets(subfeatureData) + } + + private func checkTargets(_ subfeatureData: PrivacyConfigurationData.PrivacyFeature.Feature?) -> PrivacyConfigurationFeatureState { + // Check Targets + if let targets = subfeatureData?.targets, !matchTargets(targets: targets){ + return .disabled(.targetDoesNotMatch) + } return .enabled } + private func matchTargets(targets: [PrivacyConfigurationData.PrivacyFeature.Feature.Target]) -> Bool { + targets.contains { target in + (target.localeCountry == nil || target.localeCountry == locale.regionCode) && + (target.localeLanguage == nil || target.localeLanguage == locale.languageCode) + } + } + private func subfeatures(for feature: PrivacyFeature) -> PrivacyConfigurationData.PrivacyFeature.Features { return data.features[feature.rawValue]?.features ?? [:] } @@ -247,7 +280,7 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { guard let domain = domain else { return true } return !isTempUnprotected(domain: domain) && !isUserUnprotected(domain: domain) && - !isInExceptionList(domain: domain, forFeature: .contentBlocking) + !isInExceptionList(domain: domain, forFeature: .contentBlocking) } public func isUserUnprotected(domain: String?) -> Bool { @@ -297,7 +330,25 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration { public func userDisabledProtection(forDomain domain: String) { locallyUnprotected.disableProtection(forDomain: domain.punycodeEncodedHostname.lowercased()) } +} +extension AppPrivacyConfiguration { + + public func stateFor(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID, versionProvider: AppVersionProvider, + randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + guard let parentFeature = PrivacyFeature(rawValue: parentFeatureID) else { return .disabled(.featureMissing) } + guard let subfeatureData = subfeatures(for: parentFeature)[subfeatureID] else { return .disabled(.featureMissing) } + return stateFor(subfeatureID: subfeatureID, subfeatureData: subfeatureData, parentFeature: parentFeature, versionProvider: versionProvider, randomizer: randomizer) + } + + public func cohorts(for subfeature: any PrivacySubfeature) -> [PrivacyConfigurationData.Cohort]? { + subfeatures(for: subfeature.parent)[subfeature.rawValue]?.cohorts + } + + public func cohorts(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID) -> [PrivacyConfigurationData.Cohort]? { + guard let parentFeature = PrivacyFeature(rawValue: parentFeatureID) else { return nil } + return subfeatures(for: parentFeature)[subfeatureID]?.cohorts + } } extension Array where Element == String { diff --git a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift b/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift deleted file mode 100644 index abd01290b..000000000 --- a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift +++ /dev/null @@ -1,107 +0,0 @@ -// -// ExperimentCohortsManager.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation - -struct ExperimentSubfeature { - let subfeatureID: SubfeatureID - let cohorts: [PrivacyConfigurationData.Cohort] -} - -typealias CohortID = String -typealias SubfeatureID = String - -struct ExperimentData: Codable, Equatable { - let cohort: String - let enrollmentDate: Date -} - -typealias Experiments = [String: ExperimentData] - -protocol ExperimentCohortsManaging { - /// Retrieves the cohort ID associated with the specified subfeature. - /// - Parameter subfeatureID: The name of the experiment subfeature for which the cohort ID is needed. - /// - Returns: The cohort ID as a `String` if one exists; otherwise, returns `nil`. - func cohort(for subfeatureID: SubfeatureID) -> CohortID? - - /// Retrieves the enrollment date for the specified subfeature. - /// - Parameter subfeatureID: The name of the experiment subfeature for which the enrollment date is needed. - /// - Returns: The `Date` of enrollment if one exists; otherwise, returns `nil`. - func enrollmentDate(for subfeatureID: SubfeatureID) -> Date? - - /// Assigns a cohort to the given subfeature based on defined weights and saves it to UserDefaults. - /// - Parameter subfeature: The ExperimentSubfeature to which a cohort needs to be assigned to. - /// - Returns: The name of the assigned cohort, or `nil` if no cohort could be assigned. - func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID? - - /// Removes the assigned cohort data for the specified subfeature. - /// - Parameter subfeatureID: The name of the experiment subfeature for which the cohort data should be removed. - func removeCohort(from subfeatureID: SubfeatureID) -} - -final class ExperimentCohortsManager: ExperimentCohortsManaging { - - private var store: ExperimentsDataStoring - private let randomizer: (Range) -> Double - - init(store: ExperimentsDataStoring = ExperimentsDataStore(), randomizer: @escaping (Range) -> Double) { - self.store = store - self.randomizer = randomizer - } - - func cohort(for subfeatureID: SubfeatureID) -> CohortID? { - guard let experiments = store.experiments else { return nil } - return experiments[subfeatureID]?.cohort - } - - func enrollmentDate(for subfeatureID: SubfeatureID) -> Date? { - guard let experiments = store.experiments else { return nil } - return experiments[subfeatureID]?.enrollmentDate - } - - func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID? { - let cohorts = subfeature.cohorts - let totalWeight = cohorts.map(\.weight).reduce(0, +) - guard totalWeight > 0 else { return nil } - - let randomValue = randomizer(0..) -> Double) -> PrivacyConfigurationFeatureState + + func cohorts(for subfeature: any PrivacySubfeature) -> [PrivacyConfigurationData.Cohort]? + + func cohorts(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID) -> [PrivacyConfigurationData.Cohort]? } public extension PrivacyConfiguration { @@ -120,4 +129,9 @@ public extension PrivacyConfiguration { func stateFor(_ subfeature: any PrivacySubfeature, randomizer: (Range) -> Double = Double.random(in:)) -> PrivacyConfigurationFeatureState { return stateFor(subfeature, versionProvider: AppVersionProvider(), randomizer: randomizer) } + + func stateFor(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID, randomizer: (Range) -> Double = Double.random(in:)) -> PrivacyConfigurationFeatureState { + return stateFor(subfeatureID: subfeatureID, parentFeatureID: parentFeatureID, versionProvider: AppVersionProvider(), randomizer: randomizer) + } + } diff --git a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift index 7cbd2bc71..f6f91567f 100644 --- a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift +++ b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift @@ -136,6 +136,8 @@ public struct PrivacyConfigurationData { case minSupportedVersion case rollout case cohorts + case targets + case settings } public struct Rollout: Hashable { @@ -169,10 +171,27 @@ public struct PrivacyConfigurationData { } } + public struct Target { + enum CodingKeys: String { + case localeCountry + case localeLanguage + } + + public let localeCountry: String? + public let localeLanguage: String? + + public init(json: [String: Any]) { + self.localeCountry = json[CodingKeys.localeCountry.rawValue] as? String + self.localeLanguage = json[CodingKeys.localeLanguage.rawValue] as? String + } + } + public let state: FeatureState public let minSupportedVersion: FeatureSupportedVersion? public let rollout: Rollout? public let cohorts: [Cohort]? + public let targets: [Target]? + public let settings: String? public init?(json: [String: Any]) { guard let state = json[CodingKeys.state.rawValue] as? String else { @@ -194,6 +213,19 @@ public struct PrivacyConfigurationData { } else { cohorts = nil } + + if let targetData = json[CodingKeys.targets.rawValue] as? [[String: Any]] { + targets = targetData.compactMap { Target(json: $0) } + } else { + targets = nil + } + + if let settingsData = json[CodingKeys.settings.rawValue] { + let jsonData = try? JSONSerialization.data(withJSONObject: settingsData, options: []) + settings = jsonData != nil ? String(data: jsonData!, encoding: .utf8) : nil + } else { + settings = nil + } } } diff --git a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift index 219a01613..004729e70 100644 --- a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift +++ b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift @@ -55,6 +55,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging { private let localProtection: DomainsProtectionStore private let errorReporting: EventMapping? private let installDate: Date? + private let locale: Locale public let internalUserDecider: InternalUserDecider @@ -110,12 +111,14 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging { localProtection: DomainsProtectionStore, errorReporting: EventMapping? = nil, internalUserDecider: InternalUserDecider, + locale: Locale = Locale.current, installDate: Date? = nil ) { self.embeddedDataProvider = embeddedDataProvider self.localProtection = localProtection self.errorReporting = errorReporting self.internalUserDecider = internalUserDecider + self.locale = locale self.installDate = installDate reload(etag: fetchedETag, data: fetchedData) @@ -127,6 +130,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging { identifier: fetchedData.etag, localProtection: localProtection, internalUserDecider: internalUserDecider, + locale: locale, installDate: installDate) } @@ -134,6 +138,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging { identifier: embeddedConfigData.etag, localProtection: localProtection, internalUserDecider: internalUserDecider, + locale: locale, installDate: installDate) } diff --git a/Sources/Common/Concurrency/TaskExtension.swift b/Sources/Common/Concurrency/TaskExtension.swift index a407e4406..65d974a36 100644 --- a/Sources/Common/Concurrency/TaskExtension.swift +++ b/Sources/Common/Concurrency/TaskExtension.swift @@ -18,51 +18,72 @@ import Foundation +public struct Sleeper { + + public static let `default` = Sleeper(sleep: { + try await Task.sleep(interval: $0) + }) + + private let sleep: (TimeInterval) async throws -> Void + + public init(sleep: @escaping (TimeInterval) async throws -> Void) { + self.sleep = sleep + } + + @available(macOS 13.0, iOS 16.0, *) + public init(clock: any Clock) { + self.sleep = { interval in + try await clock.sleep(for: .nanoseconds(UInt64(interval * Double(NSEC_PER_SEC)))) + } + } + + public func sleep(for interval: TimeInterval) async throws { + try await sleep(interval) + } + +} + +public func performPeriodicJob(withDelay delay: TimeInterval? = nil, + interval: TimeInterval, + sleeper: Sleeper = .default, + operation: @escaping @Sendable () async throws -> Void, + cancellationHandler: (@Sendable () async -> Void)? = nil) async throws -> Never { + + do { + if let delay { + try await sleeper.sleep(for: delay) + } + + repeat { + try await operation() + + try await sleeper.sleep(for: interval) + } while true + } catch let error as CancellationError { + await cancellationHandler?() + throw error + } +} + public extension Task where Success == Never, Failure == Error { static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async -> Void, cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task { - Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } - } + return periodic(delay: delay, interval: interval, sleeper: sleeper, operation: { await operation() } as @Sendable () async throws -> Void, cancellationHandler: cancellationHandler) } static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async throws -> Void, cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task { Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - try await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } + try await performPeriodicJob(withDelay: delay, interval: interval, sleeper: sleeper, operation: operation, cancellationHandler: cancellationHandler) } } } diff --git a/Sources/Common/Extensions/DateExtension.swift b/Sources/Common/Extensions/DateExtension.swift index fdabdee83..2a4ca74ea 100644 --- a/Sources/Common/Extensions/DateExtension.swift +++ b/Sources/Common/Extensions/DateExtension.swift @@ -64,7 +64,11 @@ public extension Date { } var startOfDay: Date { - return Calendar.current.startOfDay(for: self) + return Calendar.current.startOfDay(for: self) + } + + func daysAgo(_ days: Int) -> Date { + Calendar.current.date(byAdding: .day, value: -days, to: self)! } static var startOfMinuteNow: Date { diff --git a/Sources/Common/Extensions/HashExtension.swift b/Sources/Common/Extensions/HashExtension.swift index b6752cf57..13095cf63 100644 --- a/Sources/Common/Extensions/HashExtension.swift +++ b/Sources/Common/Extensions/HashExtension.swift @@ -42,8 +42,13 @@ extension Data { extension String { public var sha1: String { - let dataBytes = data(using: .utf8)! - return dataBytes.sha1 + let result = utf8data.sha1 + return result + } + + public var sha256: String { + let result = utf8data.sha256 + return result } } diff --git a/Sources/Common/Extensions/StringExtension.swift b/Sources/Common/Extensions/StringExtension.swift index 09050cfe2..9282a43b4 100644 --- a/Sources/Common/Extensions/StringExtension.swift +++ b/Sources/Common/Extensions/StringExtension.swift @@ -394,9 +394,9 @@ public extension String { // MARK: Regex - func matches(_ regex: NSRegularExpression) -> Bool { - let matches = regex.matches(in: self, options: .anchored, range: self.fullRange) - return matches.count == 1 + func matches(_ regex: RegEx) -> Bool { + let firstMatch = firstMatch(of: regex, options: .anchored) + return firstMatch != nil } func matches(pattern: String, options: NSRegularExpression.Options = [.caseInsensitive]) -> Bool { @@ -406,7 +406,7 @@ public extension String { return matches(regex) } - func replacing(_ regex: NSRegularExpression, with replacement: String) -> String { + func replacing(_ regex: RegEx, with replacement: String) -> String { regex.stringByReplacingMatches(in: self, range: self.fullRange, withTemplate: replacement) } diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift index d19751148..ce68773d5 100644 --- a/Sources/Common/Extensions/URLExtension.swift +++ b/Sources/Common/Extensions/URLExtension.swift @@ -354,22 +354,24 @@ extension URL { // MARK: - Parameters + @_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order. public func appendingParameters(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL where QueryParams.Element == (key: String, value: String) { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result + } - return parameters.reduce(self) { partialResult, parameter in - partialResult.appendingParameter( - name: parameter.key, - value: parameter.value, - allowedReservedCharacters: allowedReservedCharacters - ) - } + public func appendingParameters(_ parameters: KeyValuePairs, allowedReservedCharacters: CharacterSet? = nil) -> URL { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result } public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL { - let queryItem = URLQueryItem(percentEncodingName: name, - value: value, - withAllowedCharacters: allowedReservedCharacters) + let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) return self.appending(percentEncodedQueryItem: queryItem) } @@ -378,13 +380,15 @@ extension URL { } public func appending(percentEncodedQueryItems: [URLQueryItem]) -> URL { - guard var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } + guard !percentEncodedQueryItems.isEmpty, + var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]() existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems) components.percentEncodedQueryItems = existingPercentEncodedQueryItems + let result = components.url ?? self - return components.url ?? self + return result } public func getQueryItems() -> [URLQueryItem]? { diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift new file mode 100644 index 000000000..d16669d4b --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -0,0 +1,129 @@ +// +// APIClient.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import Foundation +import Networking + +extension APIClient { + // used internally for testing + protocol Mockable { + func load(_ requestConfig: Request) async throws -> Request.Response + } +} +extension APIClient: APIClient.Mockable {} + +public protocol APIClientEnvironment { + func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 + func url(for requestType: APIRequestType) -> URL +} + +public extension MaliciousSiteDetector { + enum APIEnvironment: APIClientEnvironment { + + case production + case staging + + var endpoint: URL { + switch self { + case .production: URL(string: "https://duckduckgo.com/api/protection/")! + case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")! + } + } + + var defaultHeaders: APIRequestV2.HeadersV2 { + .init(userAgent: Networking.APIRequest.Headers.userAgent) + } + + enum APIPath { + static let filterSet = "filterSet" + static let hashPrefix = "hashPrefix" + static let matches = "matches" + } + + enum QueryParameter { + static let category = "category" + static let revision = "revision" + static let hashPrefix = "hashPrefix" + } + + public func url(for requestType: APIRequestType) -> URL { + switch requestType { + case .hashPrefixSet(let configuration): + endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .filterSet(let configuration): + endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .matches(let configuration): + endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix) + } + } + + public func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 { + defaultHeaders + } + } + +} + +struct APIClient { + + let environment: APIClientEnvironment + private let service: APIService + + init(environment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared)) { + self.environment = environment + self.service = service + } + + func load(_ requestConfig: R) async throws -> R.Response { + let requestType = requestConfig.requestType + let headers = environment.headers(for: requestType) + let url = environment.url(for: requestType) + + let apiRequest = APIRequestV2(url: url, method: .get, headers: headers, timeoutInterval: requestConfig.timeout ?? 60) + let response = try await service.fetch(request: apiRequest) + let result: R.Response = try response.decodeBody() + + return result + } + +} + +// MARK: - Convenience +extension APIClient.Mockable { + func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { + let result = try await load(.filterSet(threatKind: threatKind, revision: revision)) + return result + } + + func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { + let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision)) + return result + } + + func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { + let result = try await load(.matches(hashPrefix: hashPrefix)) + return result + } +} diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift new file mode 100644 index 000000000..39fb623bd --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -0,0 +1,112 @@ +// +// APIRequest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +// Enumerated request type to delegate URLs forming to an API environment instance +public enum APIRequestType { + case hashPrefixSet(APIRequestType.HashPrefixes) + case filterSet(APIRequestType.FilterSet) + case matches(APIRequestType.Matches) +} + +extension APIClient { + // Protocol for defining typed requests with a specific response type. + protocol Request { + associatedtype Response: Decodable // Strongly-typed response type + var requestType: APIRequestType { get } // Enumerated type of request being made + var timeout: TimeInterval? { get } + } + + // Protocol for requests that modify a set of malicious site detection data + // (returning insertions/removals along with the updated revision) + protocol ChangeSetRequest: Request { + init(threatKind: ThreatKind, revision: Int?) + } +} +extension APIClient.Request { + var timeout: TimeInterval? { nil } +} + +public extension APIRequestType { + struct HashPrefixes: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.HashPrefixesChangeSet + + let threatKind: ThreatKind + let revision: Int? + + init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } + + var requestType: APIRequestType { + .hashPrefixSet(self) + } + } +} +/// extension to call generic `load(_: some Request)` method like this: `load(.hashPrefixes(…))` +extension APIClient.Request where Self == APIRequestType.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIRequestType { + struct FilterSet: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.FiltersChangeSet + + let threatKind: ThreatKind + let revision: Int? + + init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } + + var requestType: APIRequestType { + .filterSet(self) + } + } +} +/// extension to call generic `load(_: some Request)` method like this: `load(.filterSet(…))` +extension APIClient.Request where Self == APIRequestType.FilterSet { + static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIRequestType { + struct Matches: APIClient.Request { + typealias Response = APIClient.Response.Matches + + let hashPrefix: String + + var requestType: APIRequestType { + .matches(self) + } + + var timeout: TimeInterval? { 1 } + } +} +/// extension to call generic `load(_: some Request)` method like this: `load(.matches(…))` +extension APIClient.Request where Self == APIRequestType.Matches { + static func matches(hashPrefix: String) -> Self { + .init(hashPrefix: hashPrefix) + } +} diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift new file mode 100644 index 000000000..eaf4f287c --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift @@ -0,0 +1,47 @@ +// +// ChangeSetResponse.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +extension APIClient { + + public struct ChangeSetResponse: Codable, Equatable { + let insert: [T] + let delete: [T] + let revision: Int + let replace: Bool + + public init(insert: [T], delete: [T], revision: Int, replace: Bool) { + self.insert = insert + self.delete = delete + self.revision = revision + self.replace = replace + } + + public var isEmpty: Bool { + insert.isEmpty && delete.isEmpty + } + } + + public enum Response { + public typealias FiltersChangeSet = ChangeSetResponse + public typealias HashPrefixesChangeSet = ChangeSetResponse + public typealias Matches = MatchResponse + } + +} diff --git a/Sources/MaliciousSiteProtection/API/MatchResponse.swift b/Sources/MaliciousSiteProtection/API/MatchResponse.swift new file mode 100644 index 000000000..2cb6df962 --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/MatchResponse.swift @@ -0,0 +1,29 @@ +// +// MatchResponse.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +extension APIClient { + + public struct MatchResponse: Codable, Equatable { + public var matches: [Match] + + public init(matches: [Match]) { + self.matches = matches + } + } + +} diff --git a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift similarity index 53% rename from Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift rename to Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift index 81a4648d6..3e44f3bcd 100644 --- a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift +++ b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift @@ -1,5 +1,5 @@ // -// Dictionary+URLQueryItem.swift +// Logger+MaliciousSiteProtection.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,19 +17,15 @@ // import Foundation -import Common +import os -extension Dictionary where Key == String, Value == String { - - func toURLQueryItems(allowedReservedCharacters: CharacterSet? = nil) -> [URLQueryItem] { - return self.map { - if let allowedReservedCharacters { - URLQueryItem(percentEncodingName: $0.key, - value: $0.value, - withAllowedCharacters: allowedReservedCharacters) - } else { - URLQueryItem(name: $0.key, value: $0.value) - } - } +public extension os.Logger { + struct MaliciousSiteProtection { + public static var general = os.Logger(subsystem: "MSP", category: "General") + public static var api = os.Logger(subsystem: "MSP", category: "API") + public static var dataManager = os.Logger(subsystem: "MSP", category: "DataManager") + public static var updateManager = os.Logger(subsystem: "MSP", category: "UpdateManager") } } + +internal typealias Logger = os.Logger.MaliciousSiteProtection diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift new file mode 100644 index 000000000..1a637a73a --- /dev/null +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -0,0 +1,130 @@ +// +// MaliciousSiteDetector.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import CryptoKit +import Foundation +import Networking + +public protocol MaliciousSiteDetecting { + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). + /// - Parameter url: The URL to evaluate. + /// - Returns: An optional `ThreatKind` indicating the type of threat, or `.none` if no threat is detected. + func evaluate(_ url: URL) async -> ThreatKind? +} + +/// Class responsible for detecting malicious sites by evaluating URLs against local filters and an external API. +/// entry point: `func evaluate(_: URL) async -> ThreatKind?` +public final class MaliciousSiteDetector: MaliciousSiteDetecting { + // Type aliases for easier symbol navigation in Xcode. + typealias PhishingDetector = MaliciousSiteDetector + typealias MalwareDetector = MaliciousSiteDetector + + private enum Constants { + static let hashPrefixStoreLength: Int = 8 + static let hashPrefixParamLength: Int = 4 + } + + private let apiClient: APIClient.Mockable + private let dataManager: DataManaging + private let eventMapping: EventMapping + + public convenience init(apiEnvironment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared), dataManager: DataManager, eventMapping: EventMapping) { + self.init(apiClient: APIClient(environment: apiEnvironment, service: service), dataManager: dataManager, eventMapping: eventMapping) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging, eventMapping: EventMapping) { + self.apiClient = apiClient + self.dataManager = dataManager + self.eventMapping = eventMapping + } + + private func checkLocalFilters(hostHash: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool { + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: threatKind)) + let matchesLocalFilters = filterSet[hostHash]?.contains(where: { regex in + canonicalUrl.absoluteString.matches(pattern: regex) + }) ?? false + + return matchesLocalFilters + } + + private func checkApiMatches(hostHash: String, canonicalUrl: URL) async -> Match? { + let hashPrefixParam = String(hostHash.prefix(Constants.hashPrefixParamLength)) + let matches: [Match] + do { + matches = try await apiClient.matches(forHashPrefix: hashPrefixParam).matches + } catch { + Logger.general.error("Error fetching matches from API: \(error)") + return nil + } + + if let match = matches.first(where: { match in + match.hash == hostHash && canonicalUrl.absoluteString.matches(pattern: match.regex) + }) { + return match + } + return nil + } + + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). + public func evaluate(_ url: URL) async -> ThreatKind? { + guard let canonicalHost = url.canonicalHost(), + let canonicalUrl = url.canonicalURL() else { return .none } + + let hostHash = canonicalHost.sha256 + let hashPrefix = String(hostHash.prefix(Constants.hashPrefixStoreLength)) + + // 1. Check for matching hash prefixes. + // The hash prefix list serves as a representation of the entire database: + // every malicious website will have a hash prefix that it collides with. + var hashPrefixMatchingThreatKinds = [ThreatKind]() + for threatKind in ThreatKind.allCases { // e.g., phishing, malware, etc. + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind)) + if hashPrefixes.contains(hashPrefix) { + hashPrefixMatchingThreatKinds.append(threatKind) + } + } + + // Return no threats if no matching hash prefixes are found in the database. + guard !hashPrefixMatchingThreatKinds.isEmpty else { return .none } + + // 2. Check local Filter Sets. + // The filter set acts as a local cache of some database entries, containing + // the 5000 most common threats (or those most likely to collide with daily + // browsing behaviors, based on Clickhouse's top 10k, ranked by Netcraft's risk rating). + for threatKind in hashPrefixMatchingThreatKinds { + let matches = await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) + if matches { + eventMapping.fire(.errorPageShown(clientSideHit: true, threatKind: threatKind)) + return threatKind + } + } + + // 3. If no locally cached filters matched, we will still make a request to the API + // to check for potential matches on our backend. + let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) + if let match { + let threatKind = match.category.flatMap(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0] + eventMapping.fire(.errorPageShown(clientSideHit: false, threatKind: threatKind)) + return threatKind + } + + return .none + } + +} diff --git a/Sources/PhishingDetection/PhishingDetectionEvents.swift b/Sources/MaliciousSiteProtection/Model/Event.swift similarity index 92% rename from Sources/PhishingDetection/PhishingDetectionEvents.swift rename to Sources/MaliciousSiteProtection/Model/Event.swift index a788e09ff..8903f4d70 100644 --- a/Sources/PhishingDetection/PhishingDetectionEvents.swift +++ b/Sources/MaliciousSiteProtection/Model/Event.swift @@ -1,5 +1,5 @@ // -// PhishingDetectionEvents.swift +// Event.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -26,8 +26,8 @@ public extension PixelKit { } } -public enum PhishingDetectionEvents: PixelKitEventV2 { - case errorPageShown(clientSideHit: Bool) +public enum Event: PixelKitEventV2 { + case errorPageShown(clientSideHit: Bool, threatKind: ThreatKind) case visitSite case iframeLoaded case updateTaskFailed48h(error: Error?) @@ -50,7 +50,7 @@ public enum PhishingDetectionEvents: PixelKitEventV2 { public var parameters: [String: String]? { switch self { - case .errorPageShown(let clientSideHit): + case .errorPageShown(let clientSideHit, threatKind: _): return [PixelKit.Parameters.clientSideHit: String(clientSideHit)] case .visitSite: return [:] diff --git a/Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift b/Sources/MaliciousSiteProtection/Model/Filter.swift similarity index 65% rename from Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift rename to Sources/MaliciousSiteProtection/Model/Filter.swift index 86b79d477..674a176e0 100644 --- a/Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift +++ b/Sources/MaliciousSiteProtection/Model/Filter.swift @@ -1,5 +1,5 @@ // -// BackgroundActivitySchedulerMock.swift +// Filter.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,19 +17,18 @@ // import Foundation -import PhishingDetection -actor MockBackgroundActivityScheduler: BackgroundActivityScheduling { - var startCalled = false - var stopCalled = false - var interval: TimeInterval = 1 - var identifier: String = "test" +public struct Filter: Codable, Hashable { + public var hash: String + public var regex: String - func start() { - startCalled = true + enum CodingKeys: String, CodingKey { + case hash + case regex } - func stop() { - stopCalled = true + public init(hash: String, regex: String) { + self.hash = hash + self.regex = regex } } diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift new file mode 100644 index 000000000..b67cd82ef --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -0,0 +1,78 @@ +// +// FilterDictionary.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +struct FilterDictionary: Codable, Equatable { + + /// Filter set revision + var revision: Int + + /// [Hash: [RegEx]] mapping + /// + /// - **Key**: SHA256 hash sum of a canonical host name + /// - **Value**: An array of regex patterns used to match whole URLs + /// + /// ``` + /// { + /// "3aeb002460381c6f258e8395d3026f571f0d9a76488dcd837639b13aed316560" : [ + /// "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?[\\/\\\\]+BETS1O\\-GIRIS[\\/\\\\]+BETS1O(?:[\\/\\\\]+|\\?|$)" + /// ], + /// ... + /// } + /// ``` + var filters: [String: Set] + + /// Subscript to access regex patterns by SHA256 host name hash + subscript(hash: String) -> Set? { + filters[hash] + } + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { + for filter in itemsToDelete { + // Remove the filter from the Set stored in the Dictionary by hash used as a key. + // If the Set becomes empty – remove the Set value from the Dictionary. + // + // The following code is equivalent to this one but without the Set value being copied + // or key being searched multiple times: + /* + if var filterSet = self.filters[filter.hash] { + filterSet.remove(filter.regex) + if filterSet.isEmpty { + self.filters[filter.hash] = nil + } else { + self.filters[filter.hash] = filterSet + } + } + */ + withUnsafeMutablePointer(to: &filters[filter.hash]) { item in + item.pointee?.remove(filter.regex) + if item.pointee?.isEmpty == true { + item.pointee = nil + } + } + } + } + + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter { + for filter in itemsToAdd { + filters[filter.hash, default: []].insert(filter.regex) + } + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift new file mode 100644 index 000000000..7aec5244d --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift @@ -0,0 +1,45 @@ +// +// HashPrefixSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/// Structure storing a Set of hash prefixes ["6fe1e7c8","1d760415",...] and a revision of the set. +struct HashPrefixSet: Codable, Equatable { + + var revision: Int + var set: Set + + init(revision: Int, items: some Sequence) { + self.revision = revision + self.set = Set(items) + } + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String { + set.subtract(itemsToDelete) + } + + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String { + set.formUnion(itemsToAdd) + } + + @inline(__always) + func contains(_ item: String) -> Bool { + set.contains(item) + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift new file mode 100644 index 000000000..8a23785ae --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift @@ -0,0 +1,71 @@ +// +// IncrementallyUpdatableDataSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +protocol IncrementallyUpdatableDataSet: Codable, Equatable { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element: Codable, Hashable + /// API Request type used to fetch updates for the data set + associatedtype APIRequest: APIClient.ChangeSetRequest where APIRequest.Response == APIClient.ChangeSetResponse + + var revision: Int { get set } + + init(revision: Int, items: some Sequence) + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Element + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Element + + /// Apply ChangeSet from local data revision to actual revision loaded from API + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) +} + +extension IncrementallyUpdatableDataSet { + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) { + if changeSet.replace { + self = .init(revision: changeSet.revision, items: changeSet.insert) + } else { + self.subtract(changeSet.delete) + self.formUnion(changeSet.insert) + self.revision = changeSet.revision + } + } +} + +extension HashPrefixSet: IncrementallyUpdatableDataSet { + typealias Element = String + typealias APIRequest = APIRequestType.HashPrefixes + + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + .hashPrefixes(threatKind: threatKind, revision: revision) + } +} + +extension FilterDictionary: IncrementallyUpdatableDataSet { + typealias Element = Filter + typealias APIRequest = APIRequestType.FilterSet + + init(revision: Int, items: some Sequence) { + let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in + result[filter.hash, default: []].insert(filter.regex) + } + self.init(revision: revision, filters: filtersDictionary) + } + + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + .filterSet(threatKind: threatKind, revision: revision) + } +} diff --git a/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift new file mode 100644 index 000000000..be67cb6fc --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift @@ -0,0 +1,34 @@ +// +// LoadableFromEmbeddedData.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +public protocol LoadableFromEmbeddedData { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element + /// Decoded data type stored in the embedded json file + associatedtype EmbeddedDataSet: Decodable, Sequence where EmbeddedDataSet.Element == Self.Element + + init(revision: Int, items: some Sequence) +} + +extension HashPrefixSet: LoadableFromEmbeddedData { + public typealias EmbeddedDataSet = [String] +} + +extension FilterDictionary: LoadableFromEmbeddedData { + public typealias EmbeddedDataSet = [Filter] +} diff --git a/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift b/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift new file mode 100644 index 000000000..8da2523f5 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift @@ -0,0 +1,94 @@ +// +// MaliciousSiteError.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public struct MaliciousSiteError: Error, Equatable { + + public enum Code: Int { + case phishing = 1 + // case malware = 2 + } + public let code: Code + public let failingUrl: URL + + public init(code: Code, failingUrl: URL) { + self.code = code + self.failingUrl = failingUrl + } + + public init(threat: ThreatKind, failingUrl: URL) { + let code: Code + switch threat { + case .phishing: + code = .phishing + // case .malware: + // code = .malware + } + self.init(code: code, failingUrl: failingUrl) + } + + public var threatKind: ThreatKind { + switch code { + case .phishing: .phishing + // case .malware: .malware + } + } + +} + +extension MaliciousSiteError: _ObjectiveCBridgeableError { + + public init?(_bridgedNSError error: NSError) { + guard error.domain == MaliciousSiteError.errorDomain, + let code = Code(rawValue: error.code), + let failingUrl = error.userInfo[NSURLErrorFailingURLErrorKey] as? URL else { return nil } + self.code = code + self.failingUrl = failingUrl + } + +} + +extension MaliciousSiteError: LocalizedError { + + public var errorDescription: String? { + switch code { + case .phishing: + return "Phishing detected" + // case .malware: + // return "Malware detected" + } + } + +} + +extension MaliciousSiteError: CustomNSError { + public static let errorDomain: String = "MaliciousSiteError" + + public var errorCode: Int { + code.rawValue + } + + public var errorUserInfo: [String: Any] { + [ + NSURLErrorFailingURLErrorKey: failingUrl, + NSLocalizedDescriptionKey: errorDescription! + ] + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/Match.swift b/Sources/MaliciousSiteProtection/Model/Match.swift new file mode 100644 index 000000000..e22cb597f --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/Match.swift @@ -0,0 +1,35 @@ +// +// Match.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public struct Match: Codable, Hashable { + var hostname: String + var url: String + var regex: String + var hash: String + let category: String? + + public init(hostname: String, url: String, regex: String, hash: String, category: String?) { + self.hostname = hostname + self.url = url + self.regex = regex + self.hash = hash + self.category = category + } +} diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift new file mode 100644 index 000000000..a064be076 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -0,0 +1,104 @@ +// +// StoredData.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +protocol MaliciousSiteDataKey: Hashable { + associatedtype EmbeddedDataSet: Decodable + associatedtype DataSet: IncrementallyUpdatableDataSet, LoadableFromEmbeddedData + + var dataType: DataManager.StoredDataType { get } + var threatKind: ThreatKind { get } +} + +public extension DataManager { + enum StoredDataType: Hashable, CaseIterable { + case hashPrefixSet(HashPrefixes) + case filterSet(FilterSet) + + enum Kind: CaseIterable { + case hashPrefixSet, filterSet + } + // keep to get a compiler error when number of cases changes + var kind: Kind { + switch self { + case .hashPrefixSet: .hashPrefixSet + case .filterSet: .filterSet + } + } + + var dataKey: any MaliciousSiteDataKey { + switch self { + case .hashPrefixSet(let key): key + case .filterSet(let key): key + } + } + + public var threatKind: ThreatKind { + switch self { + case .hashPrefixSet(let key): key.threatKind + case .filterSet(let key): key.threatKind + } + } + + public static var allCases: [DataManager.StoredDataType] { + ThreatKind.allCases.map { threatKind in + Kind.allCases.map { dataKind in + switch dataKind { + case .hashPrefixSet: .hashPrefixSet(.init(threatKind: threatKind)) + case .filterSet: .filterSet(.init(threatKind: threatKind)) + } + } + }.flatMap { $0 } + } + } +} + +public extension DataManager.StoredDataType { + struct HashPrefixes: MaliciousSiteDataKey { + typealias DataSet = HashPrefixSet + + let threatKind: ThreatKind + + var dataType: DataManager.StoredDataType { + .hashPrefixSet(self) + } + } +} +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} + +public extension DataManager.StoredDataType { + struct FilterSet: MaliciousSiteDataKey { + typealias DataSet = FilterDictionary + + let threatKind: ThreatKind + + var dataType: DataManager.StoredDataType { + .filterSet(self) + } + } +} +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.FilterSet { + static func filterSet(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} diff --git a/Sources/MaliciousSiteProtection/Model/ThreatKind.swift b/Sources/MaliciousSiteProtection/Model/ThreatKind.swift new file mode 100644 index 000000000..bec9e2996 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/ThreatKind.swift @@ -0,0 +1,27 @@ +// +// ThreatKind.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public enum ThreatKind: String, CaseIterable, Codable, CustomStringConvertible { + public var description: String { rawValue } + + case phishing + // case malware + +} diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift b/Sources/MaliciousSiteProtection/Services/DataManager.swift new file mode 100644 index 000000000..8e4426dd1 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift @@ -0,0 +1,105 @@ +// +// DataManager.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os + +protocol DataManaging { + func dataSet(for key: DataKey) async -> DataKey.DataSet + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async +} + +public actor DataManager: DataManaging { + + private let embeddedDataProvider: EmbeddedDataProviding + private let fileStore: FileStoring + + public typealias FileNameProvider = (DataManager.StoredDataType) -> String + private nonisolated let fileNameProvider: FileNameProvider + + private var store: [StoredDataType: Any] = [:] + + public init(fileStore: FileStoring, embeddedDataProvider: EmbeddedDataProviding, fileNameProvider: @escaping FileNameProvider) { + self.embeddedDataProvider = embeddedDataProvider + self.fileStore = fileStore + self.fileNameProvider = fileNameProvider + } + + func dataSet(for key: DataKey) -> DataKey.DataSet { + let dataType = key.dataType + // return cached dataSet if available + if let data = store[key.dataType] as? DataKey.DataSet { + return data + } + + // read stored dataSet if it‘s newer than the embedded one + let dataSet = readStoredDataSet(for: key) ?? { + // no stored dataSet or the embedded one is newer + let embeddedRevision = embeddedDataProvider.revision(for: dataType) + let embeddedItems = embeddedDataProvider.loadDataSet(for: key) + return .init(revision: embeddedRevision, items: embeddedItems) + }() + + // cache + store[dataType] = dataSet + + return dataSet + } + + private func readStoredDataSet(for key: DataKey) -> DataKey.DataSet? { + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + guard let data = fileStore.read(from: fileName) else { return nil } + + let storedDataSet: DataKey.DataSet + do { + storedDataSet = try JSONDecoder().decode(DataKey.DataSet.self, from: data) + } catch { + Logger.dataManager.error("Error decoding \(fileName): \(error.localizedDescription)") + return nil + } + + // compare to the embedded data revision + let embeddedDataRevision = embeddedDataProvider.revision(for: dataType) + guard storedDataSet.revision >= embeddedDataRevision else { + Logger.dataManager.error("Stored \(fileName) is outdated: revision: \(storedDataSet.revision), embedded revision: \(embeddedDataRevision).") + return nil + } + + return storedDataSet + } + + func store(_ dataSet: DataKey.DataSet, for key: DataKey) { + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + self.store[dataType] = dataSet + + let data: Data + do { + data = try JSONEncoder().encode(dataSet) + } catch { + Logger.dataManager.error("Error encoding \(fileName): \(error.localizedDescription)") + assertionFailure("Failed to store data to \(fileName): \(error)") + return + } + + let success = fileStore.write(data: data, to: fileName) + assert(success) + } + +} diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift new file mode 100644 index 000000000..942c6214a --- /dev/null +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -0,0 +1,56 @@ +// +// EmbeddedDataProvider.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import CryptoKit + +public protocol EmbeddedDataProviding { + func revision(for dataType: DataManager.StoredDataType) -> Int + func url(for dataType: DataManager.StoredDataType) -> URL + func hash(for dataType: DataManager.StoredDataType) -> String + + func data(withContentsOf url: URL) throws -> Data +} + +extension EmbeddedDataProviding { + + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet { + let dataType = key.dataType + let url = url(for: dataType) + let data: Data + do { + data = try self.data(withContentsOf: url) +#if DEBUG + assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)") +#endif + } catch { + fatalError("\(self): Could not load embedded data set at “\(url)”: \(error)") + } + do { + let result = try JSONDecoder().decode(DataKey.EmbeddedDataSet.self, from: data) + return result + } catch { + fatalError("\(self): Could not decode embedded data set at “\(url)”: \(error)") + } + } + + public func data(withContentsOf url: URL) throws -> Data { + try Data(contentsOf: url) + } + +} diff --git a/Sources/MaliciousSiteProtection/Services/FileStore.swift b/Sources/MaliciousSiteProtection/Services/FileStore.swift new file mode 100644 index 000000000..06418e6a2 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Services/FileStore.swift @@ -0,0 +1,67 @@ +// +// FileStore.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os + +public protocol FileStoring { + @discardableResult func write(data: Data, to filename: String) -> Bool + func read(from filename: String) -> Data? +} + +public struct FileStore: FileStoring, CustomDebugStringConvertible { + private let dataStoreURL: URL + + public init(dataStoreURL: URL) { + self.dataStoreURL = dataStoreURL + createDirectoryIfNeeded() + } + + private func createDirectoryIfNeeded() { + do { + try FileManager.default.createDirectory(at: dataStoreURL, withIntermediateDirectories: true, attributes: nil) + } catch { + Logger.dataManager.error("Failed to create directory: \(error.localizedDescription)") + } + } + + public func write(data: Data, to filename: String) -> Bool { + let fileURL = dataStoreURL.appendingPathComponent(filename) + do { + try data.write(to: fileURL) + return true + } catch { + Logger.dataManager.error("Error writing to directory: \(error.localizedDescription)") + return false + } + } + + public func read(from filename: String) -> Data? { + let fileURL = dataStoreURL.appendingPathComponent(filename) + do { + return try Data(contentsOf: fileURL) + } catch { + Logger.dataManager.error("Error accessing application support directory: \(error)") + return nil + } + } + + public var debugDescription: String { + return "<\(type(of: self)) - \"\(dataStoreURL.path)\">" + } +} diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift new file mode 100644 index 000000000..57394edbf --- /dev/null +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -0,0 +1,101 @@ +// +// UpdateManager.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import Foundation +import Networking +import os + +protocol UpdateManaging { + func updateData(for key: some MaliciousSiteDataKey) async + + func startPeriodicUpdates() -> Task +} + +public struct UpdateManager: UpdateManaging { + + private let apiClient: APIClient.Mockable + private let dataManager: DataManaging + + public typealias UpdateIntervalProvider = (DataManager.StoredDataType) -> TimeInterval? + private let updateIntervalProvider: UpdateIntervalProvider + private let sleeper: Sleeper + + public init(apiEnvironment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared), dataManager: DataManager, updateIntervalProvider: @escaping UpdateIntervalProvider) { + self.init(apiClient: APIClient(environment: apiEnvironment, service: service), dataManager: dataManager, updateIntervalProvider: updateIntervalProvider) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging, sleeper: Sleeper = .default, updateIntervalProvider: @escaping UpdateIntervalProvider) { + self.apiClient = apiClient + self.dataManager = dataManager + self.updateIntervalProvider = updateIntervalProvider + self.sleeper = sleeper + } + + func updateData(for key: DataKey) async { + // load currently stored data set + var dataSet = await dataManager.dataSet(for: key) + let oldRevision = dataSet.revision + + // get change set from current revision from API + let changeSet: APIClient.ChangeSetResponse + do { + let request = DataKey.DataSet.APIRequest(threatKind: key.threatKind, revision: oldRevision) + changeSet = try await apiClient.load(request) + } catch { + Logger.updateManager.error("error fetching filter set: \(error)") + return + } + guard !changeSet.isEmpty || changeSet.revision != dataSet.revision else { + Logger.updateManager.debug("no changes to filter set") + return + } + + // apply changes + dataSet.apply(changeSet) + + // store back + await self.dataManager.store(dataSet, for: key) + Logger.updateManager.debug("\(type(of: key)).\(key.threatKind) updated from rev.\(oldRevision) to rev.\(dataSet.revision)") + } + + public func startPeriodicUpdates() -> Task { + Task.detached { + // run update jobs in background for every data type + try await withThrowingTaskGroup(of: Never.self) { group in + for dataType in DataManager.StoredDataType.allCases { + // get update interval from provider + guard let updateInterval = updateIntervalProvider(dataType) else { continue } + guard updateInterval > 0 else { + assertionFailure("Update interval for \(dataType) must be positive") + continue + } + + group.addTask { + // run periodically until the parent task is cancelled + try await performPeriodicJob(interval: updateInterval, sleeper: sleeper) { + await self.updateData(for: dataType.dataKey) + } + } + } + for try await _ in group {} + } + } + } + +} diff --git a/Sources/Navigation/Extensions/WKErrorExtension.swift b/Sources/Navigation/Extensions/WKErrorExtension.swift index f1a5c238d..de750e766 100644 --- a/Sources/Navigation/Extensions/WKErrorExtension.swift +++ b/Sources/Navigation/Extensions/WKErrorExtension.swift @@ -33,6 +33,14 @@ extension WKError { code.rawValue == NSURLErrorCancelled && _nsError.domain == NSURLErrorDomain } + public var isServerCertificateUntrusted: Bool { + _nsError.isServerCertificateUntrusted + } +} +extension NSError { + public var isServerCertificateUntrusted: Bool { + code == NSURLErrorServerCertificateUntrusted && domain == NSURLErrorDomain + } } extension WKError { diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift index a197c8a9a..7e82291d7 100644 --- a/Sources/NetworkProtection/PacketTunnelProvider.swift +++ b/Sources/NetworkProtection/PacketTunnelProvider.swift @@ -42,7 +42,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { case rekeyAttempt(_ step: RekeyAttemptStep) case failureRecoveryAttempt(_ step: FailureRecoveryStep) case serverMigrationAttempt(_ step: ServerMigrationAttemptStep) - case malformedErrorDetected(_ error: Error) } public enum AttemptStep: CustomDebugStringConvertible { @@ -705,7 +704,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // Check that the error is valid and able to be re-thrown to the OS before shutting the tunnel down if let wrappedError = wrapped(error: error) { // Wait for the provider to complete its pixel request. - providerEvents.fire(.malformedErrorDetected(error)) try? await Task.sleep(interval: .seconds(2)) throw wrappedError } else { @@ -743,7 +741,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // Check that the error is valid and able to be re-thrown to the OS before shutting the tunnel down if let wrappedError = wrapped(error: error) { // Wait for the provider to complete its pixel request. - providerEvents.fire(.malformedErrorDetected(error)) try? await Task.sleep(interval: .seconds(2)) throw wrappedError } else { @@ -887,6 +884,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { @MainActor private func stopTunnel() async throws { connectionStatus = .disconnecting + await stopMonitors() try await stopAdapter() } @@ -895,10 +893,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { private func stopAdapter() async throws { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in adapter.stop { [weak self] error in - if let self { - self.handleAdapterStopped() - } - if let error { self?.debugEvents.fire(error.networkProtectionError) @@ -1420,11 +1414,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { try await startMonitors(testImmediately: testImmediately) } - @MainActor - public func handleAdapterStopped() { - connectionStatus = .disconnected - } - // MARK: - Monitors private func startTunnelFailureMonitor() async { diff --git a/Sources/Networking/README.md b/Sources/Networking/README.md index 83a2c5ce3..751ee63d8 100644 --- a/Sources/Networking/README.md +++ b/Sources/Networking/README.md @@ -19,7 +19,7 @@ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: [.allowHTTPNotModified, .requireETagHeader, .requireUserAgent], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService(urlSession: URLSession.shared) ``` @@ -55,12 +55,12 @@ The `MockPIService` implementing `APIService` can be found in `BSK/TestUtils` ``` let apiResponse = (Data(), HTTPURLResponse(url: HTTPURLResponse.testUrl, - statusCode: 200, - httpVersion: nil, - headerFields: nil)!) -let mockedAPIService = MockAPIService(decodableResponse: Result.failure(SomeError.testError), apiResponse: Result.success(apiResponse) ) + statusCode: 200, + httpVersion: nil, + headerFields: nil)!) +let mockedAPIService = MockAPIService(apiResponse: Result.success(apiResponse)) ``` ## v1 (Legacy) -Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. \ No newline at end of file +Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. diff --git a/Sources/Networking/v1/APIHeaders.swift b/Sources/Networking/v1/APIHeaders.swift index 6d7f0a4b0..a5786c949 100644 --- a/Sources/Networking/v1/APIHeaders.swift +++ b/Sources/Networking/v1/APIHeaders.swift @@ -25,7 +25,7 @@ public extension APIRequest { struct Headers { public typealias UserAgent = String - private static var userAgent: UserAgent? + public private(set) static var userAgent: UserAgent? public static func setUserAgent(_ userAgent: UserAgent) { self.userAgent = userAgent } diff --git a/Sources/Networking/v2/APIRequestV2.swift b/Sources/Networking/v2/APIRequestV2.swift index 07434de67..a61604861 100644 --- a/Sources/Networking/v2/APIRequestV2.swift +++ b/Sources/Networking/v2/APIRequestV2.swift @@ -16,12 +16,11 @@ // limitations under the License. // +import Common import Foundation public struct APIRequestV2: CustomDebugStringConvertible { - public typealias QueryItems = [String: String] - let timeoutInterval: TimeInterval let responseConstraints: [APIResponseConstraints]? public let urlRequest: URLRequest @@ -37,25 +36,25 @@ public struct APIRequestV2: CustomDebugStringConvertible { /// - cachePolicy: The request cache policy, default is `.useProtocolCachePolicy` /// - responseRequirements: The response requirements /// - allowedQueryReservedCharacters: The characters in this character set will not be URL encoded in the query parameters - public init?(url: URL, - method: HTTPRequestMethod = .get, - queryItems: QueryItems? = nil, - headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), - body: Data? = nil, - timeoutInterval: TimeInterval = 60.0, - cachePolicy: URLRequest.CachePolicy? = nil, - responseConstraints: [APIResponseConstraints]? = nil, - allowedQueryReservedCharacters: CharacterSet? = nil) { + public init( + url: URL, + method: HTTPRequestMethod = .get, + queryItems: QueryParams?, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) where QueryParams.Element == (key: String, value: String) { + self.timeoutInterval = timeoutInterval self.responseConstraints = responseConstraints - // Generate URL request - guard var urlComps = URLComponents(url: url, resolvingAgainstBaseURL: true) else { - return nil - } - urlComps.queryItems = queryItems?.toURLQueryItems(allowedReservedCharacters: allowedQueryReservedCharacters) - guard let finalURL = urlComps.url else { - return nil + let finalURL = if let queryItems { + url.appendingParameters(queryItems, allowedReservedCharacters: allowedQueryReservedCharacters) + } else { + url } var request = URLRequest(url: finalURL, timeoutInterval: timeoutInterval) request.allHTTPHeaderFields = headers?.httpHeaders @@ -67,6 +66,19 @@ public struct APIRequestV2: CustomDebugStringConvertible { self.urlRequest = request } + public init( + url: URL, + method: HTTPRequestMethod = .get, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) { + self.init(url: url, method: method, queryItems: [String: String]?.none, headers: headers, body: body, timeoutInterval: timeoutInterval, cachePolicy: cachePolicy, responseConstraints: responseConstraints, allowedQueryReservedCharacters: allowedQueryReservedCharacters) + } + public var debugDescription: String { """ APIRequestV2: diff --git a/Sources/Networking/v2/APIResponseV2.swift b/Sources/Networking/v2/APIResponseV2.swift index 1b178fd93..8987e377b 100644 --- a/Sources/Networking/v2/APIResponseV2.swift +++ b/Sources/Networking/v2/APIResponseV2.swift @@ -20,8 +20,13 @@ import Foundation import os.log public struct APIResponseV2 { - let data: Data? - let httpResponse: HTTPURLResponse + public let data: Data? + public let httpResponse: HTTPURLResponse + + public init(data: Data?, httpResponse: HTTPURLResponse) { + self.data = data + self.httpResponse = httpResponse + } } public extension APIResponseV2 { diff --git a/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift b/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift index d10fecd56..ffff91188 100644 --- a/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift +++ b/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift @@ -19,8 +19,8 @@ import Foundation public protocol OnboardingNavigationDelegate: AnyObject { - func searchFor(_ query: String) - func navigateTo(url: URL) + func searchFromOnboarding(for query: String) + func navigateFromOnboarding(to url: URL) } public protocol OnboardingSearchSuggestionsPixelReporting { @@ -52,7 +52,7 @@ public struct OnboardingSearchSuggestionsViewModel { public func listItemPressed(_ item: ContextualOnboardingListItem) { pixelReporter.trackSearchSuggetionOptionTapped() - delegate?.searchFor(item.title) + delegate?.searchFromOnboarding(for: item.title) } } @@ -82,6 +82,6 @@ public struct OnboardingSiteSuggestionsViewModel { public func listItemPressed(_ item: ContextualOnboardingListItem) { guard let url = URL(string: item.title) else { return } pixelReporter.trackSiteSuggetionOptionTapped() - delegate?.navigateTo(url: url) + delegate?.navigateFromOnboarding(to: url) } } diff --git a/Sources/PhishingDetection/Logger+PhishingDetection.swift b/Sources/PhishingDetection/Logger+PhishingDetection.swift deleted file mode 100644 index 96a606772..000000000 --- a/Sources/PhishingDetection/Logger+PhishingDetection.swift +++ /dev/null @@ -1,29 +0,0 @@ -// -// Logger+PhishingDetection.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import os - -public extension Logger { - static var phishingDetection: Logger = { Logger(subsystem: "Phishing Detection", category: "") }() - static var phishingDetectionClient: Logger = { Logger(subsystem: "Phishing Detection", category: "APIClient") }() - static var phishingDetectionTasks: Logger = { Logger(subsystem: "Phishing Detection", category: "BackgroundActivities") }() - static var phishingDetectionDataProvider: Logger = { Logger(subsystem: "Phishing Detection", category: "DataProvider") }() - static var phishingDetectionDataStore: Logger = { Logger(subsystem: "Phishing Detection", category: "DataStore") }() - static var phishingDetectionUpdateManager: Logger = { Logger(subsystem: "Phishing Detection", category: "UpdateManager") }() -} diff --git a/Sources/PhishingDetection/PhishingDetectionClient.swift b/Sources/PhishingDetection/PhishingDetectionClient.swift deleted file mode 100644 index 942075b71..000000000 --- a/Sources/PhishingDetection/PhishingDetectionClient.swift +++ /dev/null @@ -1,177 +0,0 @@ -// -// PhishingDetectionClient.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -public struct HashPrefixResponse: Codable, Equatable { - public var insert: [String] - public var delete: [String] - public var revision: Int - public var replace: Bool - - public init(insert: [String], delete: [String], revision: Int, replace: Bool) { - self.insert = insert - self.delete = delete - self.revision = revision - self.replace = replace - } -} - -public struct FilterSetResponse: Codable, Equatable { - public var insert: [Filter] - public var delete: [Filter] - public var revision: Int - public var replace: Bool - - public init(insert: [Filter], delete: [Filter], revision: Int, replace: Bool) { - self.insert = insert - self.delete = delete - self.revision = revision - self.replace = replace - } -} - -public struct MatchResponse: Codable, Equatable { - public var matches: [Match] -} - -public protocol PhishingDetectionClientProtocol { - func getFilterSet(revision: Int) async -> FilterSetResponse - func getHashPrefixes(revision: Int) async -> HashPrefixResponse - func getMatches(hashPrefix: String) async -> [Match] -} - -public protocol URLSessionProtocol { - func data(for request: URLRequest) async throws -> (Data, URLResponse) -} - -extension URLSession: URLSessionProtocol {} - -extension URLSessionProtocol { - public static var defaultSession: URLSessionProtocol { - return URLSession.shared - } -} - -public class PhishingDetectionAPIClient: PhishingDetectionClientProtocol { - - public enum Environment { - case production - case staging - } - - enum Constants { - static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")! - static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")! - enum APIPath: String { - case filterSet - case hashPrefix - case matches - } - } - - private let endpointURL: URL - private let session: URLSessionProtocol! - private var headers: [String: String]? = [:] - - var filterSetURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue) - } - - var hashPrefixURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue) - } - - var matchesURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue) - } - - public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) { - switch environment { - case .production: - endpointURL = Constants.productionEndpoint - case .staging: - endpointURL = Constants.stagingEndpoint - } - self.session = session - } - - public func getFilterSet(revision: Int) async -> FilterSetResponse { - guard let url = createURL(for: .filterSet, revision: revision) else { - logDebug("🔸 Invalid filterSet revision URL: \(revision)") - return FilterSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: FilterSetResponse.self) ?? FilterSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - - public func getHashPrefixes(revision: Int) async -> HashPrefixResponse { - guard let url = createURL(for: .hashPrefix, revision: revision) else { - logDebug("🔸 Invalid hashPrefix revision URL: \(revision)") - return HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: HashPrefixResponse.self) ?? HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false) - } - - public func getMatches(hashPrefix: String) async -> [Match] { - let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)] - guard let url = createURL(for: .matches, queryItems: queryItems) else { - logDebug("🔸 Invalid matches URL: \(hashPrefix)") - return [] - } - return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? [] - } -} - -// MARK: Private Methods -extension PhishingDetectionAPIClient { - - private func logDebug(_ message: String) { - Logger.phishingDetectionClient.debug("\(message)") - } - - private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? { - // Start with the base URL and append the path component - var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true) - var items = queryItems ?? [] - if let revision = revision, revision > 0 { - items.append(URLQueryItem(name: "revision", value: String(revision))) - } - urlComponents?.queryItems = items.isEmpty ? nil : items - return urlComponents?.url - } - - private func fetch(url: URL, responseType: T.Type) async -> T? { - var request = URLRequest(url: url) - request.httpMethod = "GET" - request.allHTTPHeaderFields = headers - - do { - let (data, _) = try await session.data(for: request) - if let response = try? JSONDecoder().decode(responseType, from: data) { - return response - } else { - logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)") - } - } catch { - logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)") - } - return nil - } -} diff --git a/Sources/PhishingDetection/PhishingDetectionDataActivities.swift b/Sources/PhishingDetection/PhishingDetectionDataActivities.swift deleted file mode 100644 index 3f195d75e..000000000 --- a/Sources/PhishingDetection/PhishingDetectionDataActivities.swift +++ /dev/null @@ -1,110 +0,0 @@ -// -// PhishingDetectionDataActivities.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -public protocol BackgroundActivityScheduling: Actor { - func start() - func stop() -} - -actor BackgroundActivityScheduler: BackgroundActivityScheduling { - - private var task: Task? - private var timer: Timer? - private let interval: TimeInterval - private let identifier: String - private let activity: () async -> Void - - init(interval: TimeInterval, identifier: String, activity: @escaping () async -> Void) { - self.interval = interval - self.identifier = identifier - self.activity = activity - } - - func start() { - stop() - task = Task { - let taskId = UUID().uuidString - while !Task.isCancelled { - await activity() - do { - Logger.phishingDetectionTasks.debug("🟢 \(self.identifier) task was executed in instance \(taskId)") - try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000)) - } catch { - Logger.phishingDetectionTasks.error("🔴 Error \(self.identifier) task was cancelled before it could finish sleeping.") - break - } - } - } - } - - func stop() { - task?.cancel() - task = nil - } -} - -public protocol PhishingDetectionDataActivityHandling { - func start() - func stop() -} - -public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandling { - private var schedulers: [BackgroundActivityScheduler] - private var running: Bool = false - - var dataProvider: PhishingDetectionDataProviding - - public init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, phishingDetectionDataProvider: PhishingDetectionDataProviding, updateManager: PhishingDetectionUpdateManaging) { - let hashPrefixScheduler = BackgroundActivityScheduler( - interval: hashPrefixInterval, - identifier: "hashPrefixes.update", - activity: { await updateManager.updateHashPrefixes() } - ) - let filterSetScheduler = BackgroundActivityScheduler( - interval: filterSetInterval, - identifier: "filterSet.update", - activity: { await updateManager.updateFilterSet() } - ) - self.schedulers = [hashPrefixScheduler, filterSetScheduler] - self.dataProvider = phishingDetectionDataProvider - } - - public func start() { - if !running { - Task { - for scheduler in schedulers { - await scheduler.start() - } - } - running = true - } - } - - public func stop() { - Task { - for scheduler in schedulers { - await scheduler.stop() - } - } - running = false - } -} diff --git a/Sources/PhishingDetection/PhishingDetectionDataProvider.swift b/Sources/PhishingDetection/PhishingDetectionDataProvider.swift deleted file mode 100644 index af1c87672..000000000 --- a/Sources/PhishingDetection/PhishingDetectionDataProvider.swift +++ /dev/null @@ -1,75 +0,0 @@ -// -// PhishingDetectionDataProvider.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import CryptoKit -import Common -import os - -public protocol PhishingDetectionDataProviding { - var embeddedRevision: Int { get } - func loadEmbeddedFilterSet() -> Set - func loadEmbeddedHashPrefixes() -> Set -} - -public class PhishingDetectionDataProvider: PhishingDetectionDataProviding { - public private(set) var embeddedRevision: Int - var embeddedFilterSetURL: URL - var embeddedFilterSetDataSHA: String - var embeddedHashPrefixURL: URL - var embeddedHashPrefixDataSHA: String - - public init(revision: Int, filterSetURL: URL, filterSetDataSHA: String, hashPrefixURL: URL, hashPrefixDataSHA: String) { - embeddedFilterSetURL = filterSetURL - embeddedFilterSetDataSHA = filterSetDataSHA - embeddedHashPrefixURL = hashPrefixURL - embeddedHashPrefixDataSHA = hashPrefixDataSHA - embeddedRevision = revision - } - - private func loadData(from url: URL, expectedSHA: String) throws -> Data { - let data = try Data(contentsOf: url) - let sha256 = SHA256.hash(data: data) - let hashString = sha256.compactMap { String(format: "%02x", $0) }.joined() - - guard hashString == expectedSHA else { - throw NSError(domain: "PhishingDetectionDataProvider", code: 1001, userInfo: [NSLocalizedDescriptionKey: "SHA mismatch"]) - } - return data - } - - public func loadEmbeddedFilterSet() -> Set { - do { - let filterSetData = try loadData(from: embeddedFilterSetURL, expectedSHA: embeddedFilterSetDataSHA) - return try JSONDecoder().decode(Set.self, from: filterSetData) - } catch { - Logger.phishingDetectionDataProvider.error("🔴 Error: SHA mismatch for filterSet JSON file. Expected \(self.embeddedFilterSetDataSHA)") - return [] - } - } - - public func loadEmbeddedHashPrefixes() -> Set { - do { - let hashPrefixData = try loadData(from: embeddedHashPrefixURL, expectedSHA: embeddedHashPrefixDataSHA) - return try JSONDecoder().decode(Set.self, from: hashPrefixData) - } catch { - Logger.phishingDetectionDataProvider.error("🔴 Error: SHA mismatch for hashPrefixes JSON file. Expected \(self.embeddedHashPrefixDataSHA)") - return [] - } - } -} diff --git a/Sources/PhishingDetection/PhishingDetectionDataStore.swift b/Sources/PhishingDetection/PhishingDetectionDataStore.swift deleted file mode 100644 index f247f90b8..000000000 --- a/Sources/PhishingDetection/PhishingDetectionDataStore.swift +++ /dev/null @@ -1,266 +0,0 @@ -// -// PhishingDetectionDataStore.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -enum PhishingDetectionDataError: Error { - case empty -} - -public struct Filter: Codable, Hashable { - public var hashValue: String - public var regex: String - - enum CodingKeys: String, CodingKey { - case hashValue = "hash" - case regex - } - - public init(hashValue: String, regex: String) { - self.hashValue = hashValue - self.regex = regex - } -} - -public struct Match: Codable, Hashable { - var hostname: String - var url: String - var regex: String - var hash: String - - public init(hostname: String, url: String, regex: String, hash: String) { - self.hostname = hostname - self.url = url - self.regex = regex - self.hash = hash - } -} - -public protocol PhishingDetectionDataSaving { - var filterSet: Set { get } - var hashPrefixes: Set { get } - var currentRevision: Int { get } - func saveFilterSet(set: Set) - func saveHashPrefixes(set: Set) - func saveRevision(_ revision: Int) -} - -public class PhishingDetectionDataStore: PhishingDetectionDataSaving { - private lazy var _filterSet: Set = { - loadFilterSet() - }() - - private lazy var _hashPrefixes: Set = { - loadHashPrefix() - }() - - private lazy var _currentRevision: Int = { - loadRevision() - }() - - public private(set) var filterSet: Set { - get { _filterSet } - set { _filterSet = newValue } - } - public private(set) var hashPrefixes: Set { - get { _hashPrefixes } - set { _hashPrefixes = newValue } - } - public private(set) var currentRevision: Int { - get { _currentRevision } - set { _currentRevision = newValue } - } - - private let dataProvider: PhishingDetectionDataProviding - private let fileStorageManager: FileStorageManager - private let encoder = JSONEncoder() - private let revisionFilename = "revision.txt" - private let hashPrefixFilename = "hashPrefixes.json" - private let filterSetFilename = "filterSet.json" - - public init(dataProvider: PhishingDetectionDataProviding, - fileStorageManager: FileStorageManager? = nil) { - self.dataProvider = dataProvider - if let injectedFileStorageManager = fileStorageManager { - self.fileStorageManager = injectedFileStorageManager - } else { - self.fileStorageManager = PhishingFileStorageManager() - } - } - - private func writeHashPrefixes() { - let encoder = JSONEncoder() - do { - let hashPrefixesData = try encoder.encode(Array(hashPrefixes)) - fileStorageManager.write(data: hashPrefixesData, to: hashPrefixFilename) - } catch { - Logger.phishingDetectionDataStore.error("Error saving hash prefixes data: \(error.localizedDescription)") - } - } - - private func writeFilterSet() { - let encoder = JSONEncoder() - do { - let filterSetData = try encoder.encode(Array(filterSet)) - fileStorageManager.write(data: filterSetData, to: filterSetFilename) - } catch { - Logger.phishingDetectionDataStore.error("Error saving filter set data: \(error.localizedDescription)") - } - } - - private func writeRevision() { - let encoder = JSONEncoder() - do { - let revisionData = try encoder.encode(currentRevision) - fileStorageManager.write(data: revisionData, to: revisionFilename) - } catch { - Logger.phishingDetectionDataStore.error("Error saving revision data: \(error.localizedDescription)") - } - } - - private func loadHashPrefix() -> Set { - guard let data = fileStorageManager.read(from: hashPrefixFilename) else { - return dataProvider.loadEmbeddedHashPrefixes() - } - let decoder = JSONDecoder() - do { - if loadRevisionFromDisk() < dataProvider.embeddedRevision { - return dataProvider.loadEmbeddedHashPrefixes() - } - let onDiskHashPrefixes = Set(try decoder.decode(Set.self, from: data)) - return onDiskHashPrefixes - } catch { - Logger.phishingDetectionDataStore.error("Error decoding \(self.hashPrefixFilename): \(error.localizedDescription)") - return dataProvider.loadEmbeddedHashPrefixes() - } - } - - private func loadFilterSet() -> Set { - guard let data = fileStorageManager.read(from: filterSetFilename) else { - return dataProvider.loadEmbeddedFilterSet() - } - let decoder = JSONDecoder() - do { - if loadRevisionFromDisk() < dataProvider.embeddedRevision { - return dataProvider.loadEmbeddedFilterSet() - } - let onDiskFilterSet = Set(try decoder.decode(Set.self, from: data)) - return onDiskFilterSet - } catch { - Logger.phishingDetectionDataStore.error("Error decoding \(self.filterSetFilename): \(error.localizedDescription)") - return dataProvider.loadEmbeddedFilterSet() - } - } - - private func loadRevisionFromDisk() -> Int { - guard let data = fileStorageManager.read(from: revisionFilename) else { - return dataProvider.embeddedRevision - } - let decoder = JSONDecoder() - do { - return try decoder.decode(Int.self, from: data) - } catch { - Logger.phishingDetectionDataStore.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return dataProvider.embeddedRevision - } - } - - private func loadRevision() -> Int { - guard let data = fileStorageManager.read(from: revisionFilename) else { - return dataProvider.embeddedRevision - } - let decoder = JSONDecoder() - do { - let loadedRevision = try decoder.decode(Int.self, from: data) - if loadedRevision < dataProvider.embeddedRevision { - return dataProvider.embeddedRevision - } - return loadedRevision - } catch { - Logger.phishingDetectionDataStore.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return dataProvider.embeddedRevision - } - } -} - -extension PhishingDetectionDataStore { - public func saveFilterSet(set: Set) { - self.filterSet = set - writeFilterSet() - } - - public func saveHashPrefixes(set: Set) { - self.hashPrefixes = set - writeHashPrefixes() - } - - public func saveRevision(_ revision: Int) { - self.currentRevision = revision - writeRevision() - } -} - -public protocol FileStorageManager { - func write(data: Data, to filename: String) - func read(from filename: String) -> Data? -} - -final class PhishingFileStorageManager: FileStorageManager { - private let dataStoreURL: URL - - init() { - let dataStoreDirectory: URL - do { - dataStoreDirectory = try FileManager.default.url(for: .applicationSupportDirectory, in: .userDomainMask, appropriateFor: nil, create: true) - } catch { - Logger.phishingDetectionDataStore.error("Error accessing application support directory: \(error.localizedDescription)") - dataStoreDirectory = FileManager.default.temporaryDirectory - } - dataStoreURL = dataStoreDirectory.appendingPathComponent(Bundle.main.bundleIdentifier!, isDirectory: true) - createDirectoryIfNeeded() - } - - private func createDirectoryIfNeeded() { - do { - try FileManager.default.createDirectory(at: dataStoreURL, withIntermediateDirectories: true, attributes: nil) - } catch { - Logger.phishingDetectionDataStore.error("Failed to create directory: \(error.localizedDescription)") - } - } - - func write(data: Data, to filename: String) { - let fileURL = dataStoreURL.appendingPathComponent(filename) - do { - try data.write(to: fileURL) - } catch { - Logger.phishingDetectionDataStore.error("Error writing to directory: \(error.localizedDescription)") - } - } - - func read(from filename: String) -> Data? { - let fileURL = dataStoreURL.appendingPathComponent(filename) - do { - return try Data(contentsOf: fileURL) - } catch { - Logger.phishingDetectionDataStore.error("Error accessing application support directory: \(error)") - return nil - } - } -} diff --git a/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift b/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift deleted file mode 100644 index b811082e3..000000000 --- a/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift +++ /dev/null @@ -1,83 +0,0 @@ -// -// PhishingDetectionUpdateManager.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -public protocol PhishingDetectionUpdateManaging { - func updateFilterSet() async - func updateHashPrefixes() async -} - -public class PhishingDetectionUpdateManager: PhishingDetectionUpdateManaging { - var apiClient: PhishingDetectionClientProtocol - var dataStore: PhishingDetectionDataSaving - - public init(client: PhishingDetectionClientProtocol, dataStore: PhishingDetectionDataSaving) { - self.apiClient = client - self.dataStore = dataStore - } - - private func updateSet( - currentSet: Set, - insert: [T], - delete: [T], - replace: Bool, - saveSet: (Set) -> Void - ) { - var newSet = currentSet - - if replace { - newSet = Set(insert) - } else { - newSet.formUnion(insert) - newSet.subtract(delete) - } - - saveSet(newSet) - } - - public func updateFilterSet() async { - let response = await apiClient.getFilterSet(revision: dataStore.currentRevision) - updateSet( - currentSet: dataStore.filterSet, - insert: response.insert, - delete: response.delete, - replace: response.replace - ) { newSet in - self.dataStore.saveFilterSet(set: newSet) - } - dataStore.saveRevision(response.revision) - Logger.phishingDetectionUpdateManager.debug("filterSet updated to revision \(self.dataStore.currentRevision)") - } - - public func updateHashPrefixes() async { - let response = await apiClient.getHashPrefixes(revision: dataStore.currentRevision) - updateSet( - currentSet: dataStore.hashPrefixes, - insert: response.insert, - delete: response.delete, - replace: response.replace - ) { newSet in - self.dataStore.saveHashPrefixes(set: newSet) - } - dataStore.saveRevision(response.revision) - Logger.phishingDetectionUpdateManager.debug("hashPrefixes updated to revision \(self.dataStore.currentRevision)") - } -} diff --git a/Sources/PhishingDetection/PhishingDetector.swift b/Sources/PhishingDetection/PhishingDetector.swift deleted file mode 100644 index 3ccbe9b7e..000000000 --- a/Sources/PhishingDetection/PhishingDetector.swift +++ /dev/null @@ -1,130 +0,0 @@ -// -// PhishingDetector.swift -// -// Copyright © 2023 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import CryptoKit -import Common -import WebKit - -public enum PhishingDetectionError: CustomNSError { - case detected - - public static let errorDomain: String = "PhishingDetectionError" - - public var errorCode: Int { - switch self { - case .detected: - return 1331 - } - } - - public var errorUserInfo: [String: Any] { - switch self { - case .detected: - return [NSLocalizedDescriptionKey: "Phishing detected"] - } - } - - public var rawValue: Int { - return self.errorCode - } -} - -public protocol PhishingDetecting { - func isMalicious(url: URL) async -> Bool -} - -public class PhishingDetector: PhishingDetecting { - let hashPrefixStoreLength: Int = 8 - let hashPrefixParamLength: Int = 4 - let apiClient: PhishingDetectionClientProtocol - let dataStore: PhishingDetectionDataSaving - let eventMapping: EventMapping - - public init(apiClient: PhishingDetectionClientProtocol, dataStore: PhishingDetectionDataSaving, eventMapping: EventMapping) { - self.apiClient = apiClient - self.dataStore = dataStore - self.eventMapping = eventMapping - } - - private func getMatches(hashPrefix: String) async -> Set { - return Set(await apiClient.getMatches(hashPrefix: hashPrefix)) - } - - private func inFilterSet(hash: String) -> Set { - return Set(dataStore.filterSet.filter { $0.hashValue == hash }) - } - - private func matchesUrl(hash: String, regexPattern: String, url: URL, hostnameHash: String) -> Bool { - if hash == hostnameHash, - let regex = try? NSRegularExpression(pattern: regexPattern, options: []) { - let urlString = url.absoluteString - let range = NSRange(location: 0, length: urlString.utf16.count) - return regex.firstMatch(in: urlString, options: [], range: range) != nil - } - return false - } - - private func generateHashPrefix(for canonicalHost: String, length: Int) -> String { - let hostnameHash = SHA256.hash(data: Data(canonicalHost.utf8)).map { String(format: "%02hhx", $0) }.joined() - return String(hostnameHash.prefix(length)) - } - - private func fetchMatches(hashPrefix: String) async -> [Match] { - return await apiClient.getMatches(hashPrefix: hashPrefix) - } - - private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool { - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - let filterHit = inFilterSet(hash: hostnameHash) - for filter in filterHit where matchesUrl(hash: filter.hashValue, regexPattern: filter.regex, url: canonicalUrl, hostnameHash: hostnameHash) { - eventMapping.fire(PhishingDetectionEvents.errorPageShown(clientSideHit: true)) - return true - } - return false - } - - private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Bool { - let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: hashPrefixParamLength) - let matches = await fetchMatches(hashPrefix: hashPrefixParam) - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - for match in matches where matchesUrl(hash: match.hash, regexPattern: match.regex, url: canonicalUrl, hostnameHash: hostnameHash) { - eventMapping.fire(PhishingDetectionEvents.errorPageShown(clientSideHit: false)) - return true - } - return false - } - - public func isMalicious(url: URL) async -> Bool { - guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return false } - - let hashPrefix = generateHashPrefix(for: canonicalHost, length: hashPrefixStoreLength) - if dataStore.hashPrefixes.contains(hashPrefix) { - // Check local filterSet first - if checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return true - } - // If nothing found, hit the API to get matches - if await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return true - } - } - - return false - } -} diff --git a/Sources/PrivacyDashboard/PrivacyDashboardController.swift b/Sources/PrivacyDashboard/PrivacyDashboardController.swift index 8093a02d2..988c4cd20 100644 --- a/Sources/PrivacyDashboard/PrivacyDashboardController.swift +++ b/Sources/PrivacyDashboard/PrivacyDashboardController.swift @@ -16,12 +16,13 @@ // limitations under the License. // -import Foundation -import WebKit -import Combine -import PrivacyDashboardResources import BrowserServicesKit +import Combine import Common +import Foundation +import MaliciousSiteProtection +import PrivacyDashboardResources +import WebKit public enum PrivacyDashboardOpenSettingsTarget: String { @@ -205,7 +206,7 @@ extension PrivacyDashboardController: WKNavigationDelegate { subscribeToServerTrust() subscribeToConsentManaged() subscribeToAllowedPermissions() - subscribeToIsPhishing() + subscribeToMaliciousSiteThreatKind() } private func subscribeToTheme() { @@ -259,12 +260,12 @@ extension PrivacyDashboardController: WKNavigationDelegate { .store(in: &cancellables) } - private func subscribeToIsPhishing() { - privacyInfo?.$isPhishing + private func subscribeToMaliciousSiteThreatKind() { + privacyInfo?.$malicousSiteThreatKind .receive(on: DispatchQueue.main ) - .sink(receiveValue: { [weak self] isPhishing in - guard let self = self, let webView = self.webView else { return } - script.setIsPhishing(isPhishing, webView: webView) + .sink(receiveValue: { [weak self] detectedThreatKind in + guard let self, let webView else { return } + script.setMaliciousSiteDetectedThreatKind(detectedThreatKind, webView: webView) }) .store(in: &cancellables) } diff --git a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift index 23fc5e738..801cdd81c 100644 --- a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift +++ b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift @@ -16,12 +16,13 @@ // limitations under the License. // +import BrowserServicesKit +import Common import Foundation -import WebKit +import MaliciousSiteProtection import TrackerRadarKit import UserScript -import Common -import BrowserServicesKit +import WebKit @MainActor protocol PrivacyDashboardUserScriptDelegate: AnyObject { @@ -425,8 +426,8 @@ final class PrivacyDashboardUserScript: NSObject, StaticUserScript { evaluate(js: "window.onChangeCertificateData(\(certificateDataJson))", in: webView) } - func setIsPhishing(_ isPhishing: Bool, webView: WKWebView) { - let phishingStatus = ["phishingStatus": isPhishing] + func setMaliciousSiteDetectedThreatKind(_ detectedThreatKind: MaliciousSiteProtection.ThreatKind?, webView: WKWebView) { + let phishingStatus = ["phishingStatus": detectedThreatKind == .phishing] guard let phishingStatusJson = try? JSONEncoder().encode(phishingStatus).utf8String() else { assertionFailure("Can't encode phishingStatus into JSON") return diff --git a/Sources/PrivacyDashboard/PrivacyInfo.swift b/Sources/PrivacyDashboard/PrivacyInfo.swift index b9db906fc..3eaabc185 100644 --- a/Sources/PrivacyDashboard/PrivacyInfo.swift +++ b/Sources/PrivacyDashboard/PrivacyInfo.swift @@ -16,9 +16,10 @@ // limitations under the License. // +import Common import Foundation +import MaliciousSiteProtection import TrackerRadarKit -import Common public protocol SecurityTrust { } extension SecTrust: SecurityTrust {} @@ -33,15 +34,15 @@ public final class PrivacyInfo { @Published public var serverTrust: SecurityTrust? @Published public var connectionUpgradedTo: URL? @Published public var cookieConsentManaged: CookieConsentInfo? - @Published public var isPhishing: Bool + @Published public var malicousSiteThreatKind: MaliciousSiteProtection.ThreatKind? @Published public var isSpecialErrorPageVisible: Bool = false @Published public var shouldCheckServerTrust: Bool - public init(url: URL, parentEntity: Entity?, protectionStatus: ProtectionStatus, isPhishing: Bool = false, shouldCheckServerTrust: Bool = false) { + public init(url: URL, parentEntity: Entity?, protectionStatus: ProtectionStatus, malicousSiteThreatKind: MaliciousSiteProtection.ThreatKind? = .none, shouldCheckServerTrust: Bool = false) { self.url = url self.parentEntity = parentEntity self.protectionStatus = protectionStatus - self.isPhishing = isPhishing + self.malicousSiteThreatKind = malicousSiteThreatKind self.shouldCheckServerTrust = shouldCheckServerTrust trackerInfo = TrackerInfo() diff --git a/Sources/PrivacyStats/Logger+PrivacyStats.swift b/Sources/PrivacyStats/Logger+PrivacyStats.swift new file mode 100644 index 000000000..e8649a6af --- /dev/null +++ b/Sources/PrivacyStats/Logger+PrivacyStats.swift @@ -0,0 +1,24 @@ +// +// Logger+PrivacyStats.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log + +public extension Logger { + static var privacyStats = { Logger(subsystem: "Privacy Stats", category: "") }() +} diff --git a/Sources/PrivacyStats/PrivacyStats.swift b/Sources/PrivacyStats/PrivacyStats.swift new file mode 100644 index 000000000..c298f60fc --- /dev/null +++ b/Sources/PrivacyStats/PrivacyStats.swift @@ -0,0 +1,249 @@ +// +// PrivacyStats.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Combine +import Common +import CoreData +import Foundation +import os.log +import Persistence +import TrackerRadarKit + +/** + * Errors that may be reported by `PrivacyStats`. + */ +public enum PrivacyStatsError: CustomNSError { + case failedToFetchPrivacyStatsSummary(Error) + case failedToStorePrivacyStats(Error) + case failedToLoadCurrentPrivacyStats(Error) + + public static let errorDomain: String = "PrivacyStatsError" + + public var errorCode: Int { + switch self { + case .failedToFetchPrivacyStatsSummary: + return 1 + case .failedToStorePrivacyStats: + return 2 + case .failedToLoadCurrentPrivacyStats: + return 3 + } + } + + public var underlyingError: Error { + switch self { + case .failedToFetchPrivacyStatsSummary(let error), + .failedToStorePrivacyStats(let error), + .failedToLoadCurrentPrivacyStats(let error): + return error + } + } +} + +/** + * This protocol describes database provider consumed by `PrivacyStats`. + */ +public protocol PrivacyStatsDatabaseProviding { + func initializeDatabase() -> CoreDataDatabase +} + +/** + * This protocol describes `PrivacyStats` interface. + */ +public protocol PrivacyStatsCollecting { + + /** + * Record a tracker for a given `companyName`. + * + * `PrivacyStats` implementation calls the `CurrentPack` actor under the hood, + * and as such it can safely be called on multiple threads concurrently. + */ + func recordBlockedTracker(_ name: String) async + + /** + * Publisher emitting values whenever updated privacy stats were persisted to disk. + */ + var statsUpdatePublisher: AnyPublisher { get } + + /** + * This function fetches privacy stats in a dictionary format + * with keys being company names and values being total number + * of tracking attempts blocked in past 7 days. + */ + func fetchPrivacyStats() async -> [String: Int64] + + /** + * This function clears all blocked tracker stats from the database. + */ + func clearPrivacyStats() async + + /** + * This function saves all pending changes to the persistent storage. + * + * It should only be used in response to app termination because otherwise + * the `PrivacyStats` object schedules persisting internally. + */ + func handleAppTermination() async +} + +public final class PrivacyStats: PrivacyStatsCollecting { + + public static let bundle = Bundle.module + + public let statsUpdatePublisher: AnyPublisher + + private let db: CoreDataDatabase + private let context: NSManagedObjectContext + private let currentPack: CurrentPack + private let statsUpdateSubject = PassthroughSubject() + private var cancellables: Set = [] + + private let errorEvents: EventMapping? + + public init(databaseProvider: PrivacyStatsDatabaseProviding, errorEvents: EventMapping? = nil) { + self.db = databaseProvider.initializeDatabase() + self.context = db.makeContext(concurrencyType: .privateQueueConcurrencyType, name: "PrivacyStats") + self.errorEvents = errorEvents + self.currentPack = .init(pack: Self.initializeCurrentPack(in: context, errorEvents: errorEvents)) + statsUpdatePublisher = statsUpdateSubject.eraseToAnyPublisher() + + currentPack.commitChangesPublisher + .sink { [weak self] pack in + Task { + await self?.commitChanges(pack) + } + } + .store(in: &cancellables) + } + + public func recordBlockedTracker(_ companyName: String) async { + await currentPack.recordBlockedTracker(companyName) + } + + public func fetchPrivacyStats() async -> [String: Int64] { + return await withCheckedContinuation { continuation in + context.perform { [weak self] in + guard let self else { + continuation.resume(returning: [:]) + return + } + do { + let stats = try PrivacyStatsUtils.load7DayStats(in: context) + continuation.resume(returning: stats) + } catch { + errorEvents?.fire(.failedToFetchPrivacyStatsSummary(error)) + continuation.resume(returning: [:]) + } + } + } + } + + public func clearPrivacyStats() async { + await withCheckedContinuation { continuation in + context.perform { [weak self] in + guard let self else { + continuation.resume() + return + } + do { + try PrivacyStatsUtils.deleteAllStats(in: context) + Logger.privacyStats.debug("Deleted outdated entries") + } catch { + Logger.privacyStats.error("Save error: \(error)") + errorEvents?.fire(.failedToFetchPrivacyStatsSummary(error)) + } + continuation.resume() + } + } + await currentPack.resetPack() + statsUpdateSubject.send() + } + + public func handleAppTermination() async { + await commitChanges(currentPack.pack) + } + + // MARK: - Private + + private func commitChanges(_ pack: PrivacyStatsPack) async { + await withCheckedContinuation { continuation in + context.perform { [weak self] in + guard let self else { + continuation.resume() + return + } + + // Check if the pack we're currently storing is from a previous day. + let isCurrentDayPack = pack.timestamp == Date.currentPrivacyStatsPackTimestamp + + do { + let statsObjects = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: Set(pack.trackers.keys), in: context) + statsObjects.forEach { stats in + if let count = pack.trackers[stats.companyName] { + stats.count = count + } + } + + guard context.hasChanges else { + continuation.resume() + return + } + + try context.save() + Logger.privacyStats.debug("Saved stats \(pack.timestamp) \(pack.trackers)") + + if isCurrentDayPack { + // Only emit update event when saving current-day pack. For previous-day pack, + // a follow-up commit event will come and we'll emit the update then. + statsUpdateSubject.send() + } else { + // When storing a pack from a previous day, we may have outdated packs, so delete them as needed. + try PrivacyStatsUtils.deleteOutdatedPacks(in: context) + } + } catch { + Logger.privacyStats.error("Save error: \(error)") + errorEvents?.fire(.failedToStorePrivacyStats(error)) + } + continuation.resume() + } + } + } + + /** + * This function is only called in the initializer. It performs a blocking call to the database + * to spare us the hassle of declaring the initializer async or spawning tasks from within the + * initializer without being able to await them, thus making testing trickier. + */ + private static func initializeCurrentPack(in context: NSManagedObjectContext, errorEvents: EventMapping?) -> PrivacyStatsPack { + var pack: PrivacyStatsPack? + context.performAndWait { + let timestamp = Date.currentPrivacyStatsPackTimestamp + do { + let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context) + Logger.privacyStats.debug("Loaded stats \(timestamp) \(currentDayStats)") + pack = PrivacyStatsPack(timestamp: timestamp, trackers: currentDayStats) + + try PrivacyStatsUtils.deleteOutdatedPacks(in: context) + } catch { + Logger.privacyStats.error("Failed to load current stats: \(error)") + errorEvents?.fire(.failedToLoadCurrentPrivacyStats(error)) + } + } + return pack ?? PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp) + } +} diff --git a/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion new file mode 100644 index 000000000..1a19d1654 --- /dev/null +++ b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion @@ -0,0 +1,8 @@ + + + + + _XCCurrentVersionName + PrivacyStats.xcdatamodel + + diff --git a/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents new file mode 100644 index 000000000..39798857a --- /dev/null +++ b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Sources/PrivacyStats/internal/CurrentPack.swift b/Sources/PrivacyStats/internal/CurrentPack.swift new file mode 100644 index 000000000..fffc6b090 --- /dev/null +++ b/Sources/PrivacyStats/internal/CurrentPack.swift @@ -0,0 +1,116 @@ +// +// CurrentPack.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Combine +import Foundation +import os.log + +/** + * This actor provides thread-safe access to an instance of `PrivacyStatsPack`. + * + * It's used by `PrivacyStats` class to record blocked trackers that can possibly + * come from multiple open tabs (web views) at the same time. + */ +actor CurrentPack { + /** + * Current stats pack. + */ + private(set) var pack: PrivacyStatsPack + + /** + * Publisher that fires events whenever tracker stats are ready to be persisted to disk. + * + * This happens after recording new blocked tracker, when no new tracker has been recorded for 1s. + */ + nonisolated private(set) lazy var commitChangesPublisher: AnyPublisher = commitChangesSubject.eraseToAnyPublisher() + + nonisolated private let commitChangesSubject = PassthroughSubject() + private var commitTask: Task? + private var commitDebounce: UInt64 + + /// The `commitDebounce` parameter should only be modified in unit tests. + init(pack: PrivacyStatsPack, commitDebounce: UInt64 = 1_000_000_000) { + self.pack = pack + self.commitDebounce = commitDebounce + } + + deinit { + commitTask?.cancel() + } + + /** + * This function is used when clearing app data, to clear any stats cached in memory. + * + * It sets a new empty pack with the current timestamp. + */ + func resetPack() { + resetStats(andSet: Date.currentPrivacyStatsPackTimestamp) + } + + /** + * This function increments trackers count for a given company name. + * + * Updates are kept in memory and scheduled for saving to persistent storage with 1s debounce. + * This function also detects when the current pack becomes outdated (which happens + * when current timestamp's day becomes greater than pack's timestamp's day), in which + * case current pack is scheduled for persisting on disk and a new empty pack is + * created for the new timestamp. + */ + func recordBlockedTracker(_ companyName: String) { + + let currentTimestamp = Date.currentPrivacyStatsPackTimestamp + if currentTimestamp != pack.timestamp { + Logger.privacyStats.debug("New timestamp detected, storing trackers state and creating new pack") + notifyChanges(for: pack, immediately: true) + resetStats(andSet: currentTimestamp) + } + + let count = pack.trackers[companyName] ?? 0 + pack.trackers[companyName] = count + 1 + + notifyChanges(for: pack, immediately: false) + } + + private func notifyChanges(for pack: PrivacyStatsPack, immediately shouldPublishImmediately: Bool) { + commitTask?.cancel() + + if shouldPublishImmediately { + + commitChangesSubject.send(pack) + + } else { + + commitTask = Task { + do { + // Note that this doesn't always sleep for the full debounce time, but the sleep is interrupted + // as soon as the task gets cancelled. + try await Task.sleep(nanoseconds: commitDebounce) + + Logger.privacyStats.debug("Storing trackers state") + commitChangesSubject.send(pack) + } catch { + // Commit task got cancelled + } + } + } + } + + private func resetStats(andSet newTimestamp: Date) { + pack = PrivacyStatsPack(timestamp: newTimestamp, trackers: [:]) + } +} diff --git a/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift b/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift new file mode 100644 index 000000000..728a5943c --- /dev/null +++ b/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift @@ -0,0 +1,55 @@ +// +// DailyBlockedTrackersEntity.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import CoreData + +@objc(DailyBlockedTrackersEntity) +final class DailyBlockedTrackersEntity: NSManagedObject { + enum Const { + static let entityName = "DailyBlockedTrackersEntity" + } + + @nonobjc class func fetchRequest() -> NSFetchRequest { + NSFetchRequest(entityName: Const.entityName) + } + + class func entity(in context: NSManagedObjectContext) -> NSEntityDescription { + NSEntityDescription.entity(forEntityName: Const.entityName, in: context)! + } + + @NSManaged var companyName: String + @NSManaged var count: Int64 + @NSManaged var timestamp: Date + + private override init(entity: NSEntityDescription, insertInto context: NSManagedObjectContext?) { + super.init(entity: entity, insertInto: context) + } + + private convenience init(context moc: NSManagedObjectContext) { + self.init(entity: DailyBlockedTrackersEntity.entity(in: moc), insertInto: moc) + } + + static func make(timestamp: Date = Date(), companyName: String, count: Int64 = 0, context: NSManagedObjectContext) -> DailyBlockedTrackersEntity { + let object = DailyBlockedTrackersEntity(context: context) + object.timestamp = timestamp.privacyStatsPackTimestamp + object.companyName = companyName + object.count = count + return object + } +} diff --git a/Sources/PrivacyStats/internal/Date+PrivacyStats.swift b/Sources/PrivacyStats/internal/Date+PrivacyStats.swift new file mode 100644 index 000000000..f56072d75 --- /dev/null +++ b/Sources/PrivacyStats/internal/Date+PrivacyStats.swift @@ -0,0 +1,51 @@ +// +// Date+PrivacyStats.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import Foundation + +extension Date { + + /** + * Returns privacy stats pack timestamp for the current date. + * + * See `privacyStatsPackTimestamp`. + */ + static var currentPrivacyStatsPackTimestamp: Date { + Date().privacyStatsPackTimestamp + } + + /** + * Returns a valid timestamp for `DailyBlockedTrackersEntity` instance matching the sender. + * + * Blocked trackers are packed by day so the timestap of the pack must be the exact start of a day. + */ + var privacyStatsPackTimestamp: Date { + startOfDay + } + + /** + * Returns the oldest valid timestamp for `DailyBlockedTrackersEntity` instance matching the sender. + * + * Privacy Stats only keeps track of 7 days worth of tracking history, so the oldest timestamp is + * beginning of the day 6 days ago. + */ + var privacyStatsOldestPackTimestamp: Date { + privacyStatsPackTimestamp.daysAgo(6) + } +} diff --git a/Sources/PrivacyStats/internal/PrivacyStatsPack.swift b/Sources/PrivacyStats/internal/PrivacyStatsPack.swift new file mode 100644 index 000000000..3c4c8a04b --- /dev/null +++ b/Sources/PrivacyStats/internal/PrivacyStatsPack.swift @@ -0,0 +1,32 @@ +// +// PrivacyStatsPack.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +/** + * This struct keeps track of the summary of blocked trackers for a single unit of time (1 day). + */ +struct PrivacyStatsPack: Equatable { + let timestamp: Date + var trackers: [String: Int64] + + init(timestamp: Date, trackers: [String: Int64] = [:]) { + self.timestamp = timestamp + self.trackers = trackers + } +} diff --git a/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift b/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift new file mode 100644 index 000000000..33b93d869 --- /dev/null +++ b/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift @@ -0,0 +1,123 @@ +// +// PrivacyStatsUtils.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Common +import CoreData +import Foundation +import Persistence + +final class PrivacyStatsUtils { + + /** + * Returns objects corresponding to current stats for companies specified by `companyNames`. + * + * If an object doesn't exist (no trackers for a given company were reported on a given day) + * then a new object for that company is inserted into the context and returned. + * If a user opens the app for the first time on a given day, the database will not contain + * any records for that day and this function will only insert new objects. + * + * > Note: `current stats` refer to stats objects that are active on a given day, i.e. their + * timestamp's day matches current day. + */ + static func fetchOrInsertCurrentStats(for companyNames: Set, in context: NSManagedObjectContext) throws -> [DailyBlockedTrackersEntity] { + let timestamp = Date.currentPrivacyStatsPackTimestamp + + let request = DailyBlockedTrackersEntity.fetchRequest() + request.predicate = NSPredicate(format: "%K == %@ AND %K in %@", + #keyPath(DailyBlockedTrackersEntity.timestamp), timestamp as NSDate, + #keyPath(DailyBlockedTrackersEntity.companyName), companyNames) + request.returnsObjectsAsFaults = false + + var statsObjects = try context.fetch(request) + let missingCompanyNames = companyNames.subtracting(statsObjects.map(\.companyName)) + + for companyName in missingCompanyNames { + statsObjects.append(DailyBlockedTrackersEntity.make(timestamp: timestamp, companyName: companyName, context: context)) + } + return statsObjects + } + + /** + * Returns a dictionary representation of blocked trackers counts grouped by company name for the current day. + */ + static func loadCurrentDayStats(in context: NSManagedObjectContext) throws -> [String: Int64] { + let startDate = Date.currentPrivacyStatsPackTimestamp + return try loadBlockedTrackersStats(since: startDate, in: context) + } + + /** + * Returns a dictionary representation of blocked trackers counts grouped by company name for past 7 days. + */ + static func load7DayStats(in context: NSManagedObjectContext) throws -> [String: Int64] { + let startDate = Date().privacyStatsOldestPackTimestamp + return try loadBlockedTrackersStats(since: startDate, in: context) + } + + private static func loadBlockedTrackersStats(since startDate: Date, in context: NSManagedObjectContext) throws -> [String: Int64] { + let request = NSFetchRequest(entityName: DailyBlockedTrackersEntity.Const.entityName) + request.predicate = NSPredicate(format: "%K >= %@", #keyPath(DailyBlockedTrackersEntity.timestamp), startDate as NSDate) + + let companyNameKey = #keyPath(DailyBlockedTrackersEntity.companyName) + + // Expression description for the sum of count + let countExpression = NSExpression(forKeyPath: #keyPath(DailyBlockedTrackersEntity.count)) + let sumExpression = NSExpression(forFunction: "sum:", arguments: [countExpression]) + + let sumExpressionDescription = NSExpressionDescription() + sumExpressionDescription.name = "totalCount" + sumExpressionDescription.expression = sumExpression + sumExpressionDescription.expressionResultType = .integer64AttributeType + + request.propertiesToGroupBy = [companyNameKey] + request.propertiesToFetch = [companyNameKey, sumExpressionDescription] + request.resultType = .dictionaryResultType + + let results = (try context.fetch(request) as? [[String: Any]]) ?? [] + + let groupedResults = results.reduce(into: [String: Int64]()) { partialResult, result in + if let companyName = result[companyNameKey] as? String, let totalCount = result["totalCount"] as? Int64, totalCount > 0 { + partialResult[companyName] = totalCount + } + } + + return groupedResults + } + + /** + * Deletes stats older than 7 days for all companies. + */ + static func deleteOutdatedPacks(in context: NSManagedObjectContext) throws { + let oldestValidTimestamp = Date().privacyStatsOldestPackTimestamp + + let fetchRequest = NSFetchRequest(entityName: DailyBlockedTrackersEntity.Const.entityName) + fetchRequest.predicate = NSPredicate(format: "%K < %@", #keyPath(DailyBlockedTrackersEntity.timestamp), oldestValidTimestamp as NSDate) + let deleteRequest = NSBatchDeleteRequest(fetchRequest: fetchRequest) + + try context.execute(deleteRequest) + context.reset() + } + + /** + * Deletes all stats entries in the database. + */ + static func deleteAllStats(in context: NSManagedObjectContext) throws { + let deleteRequest = NSBatchDeleteRequest(fetchRequest: DailyBlockedTrackersEntity.fetchRequest()) + try context.execute(deleteRequest) + context.reset() + } +} diff --git a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift index dc3f3e48b..37e53e352 100644 --- a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift +++ b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift @@ -31,11 +31,16 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio private let statisticsStore: StatisticsStore private let vpnActivationDateStore: VPNActivationDateProviding private let subscription: Subscription? + private let localeIdentifier: String - public init(statisticsStore: StatisticsStore, vpnActivationDateStore: VPNActivationDateProviding, subscription: Subscription?) { + public init(statisticsStore: StatisticsStore, + vpnActivationDateStore: VPNActivationDateProviding, + subscription: Subscription?, + localeIdentifier: String = Locale.current.identifier) { self.statisticsStore = statisticsStore self.vpnActivationDateStore = vpnActivationDateStore self.subscription = subscription + self.localeIdentifier = localeIdentifier } // swiftlint:disable:next cyclomatic_complexity @@ -72,6 +77,9 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio let daysSinceInstall = Calendar.current.numberOfDaysBetween(installDate, and: Date()) { queryItems.append(URLQueryItem(name: parameter.rawValue, value: String(describing: daysSinceInstall))) } + case .locale: + let formattedLocale = LocaleMatchingAttribute.localeIdentifierAsJsonFormat(localeIdentifier) + queryItems.append(URLQueryItem(name: parameter.rawValue, value: formattedLocale)) case .privacyProStatus: if let privacyProStatusSurveyParameter = subscription?.privacyProStatusSurveyParameter { queryItems.append(URLQueryItem(name: parameter.rawValue, value: privacyProStatusSurveyParameter)) diff --git a/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift b/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift index 668b03e1c..a927e1c47 100644 --- a/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift +++ b/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift @@ -24,6 +24,7 @@ public enum RemoteMessagingSurveyActionParameter: String, CaseIterable { case atbVariant = "var" case daysInstalled = "delta" case hardwareModel = "mo" + case locale = "locale" case osVersion = "osv" case privacyProStatus = "ppro_status" case privacyProPlatform = "ppro_platform" diff --git a/Sources/SpecialErrorPages/SSLErrorType.swift b/Sources/SpecialErrorPages/SSLErrorType.swift index cf483b8e2..58137a293 100644 --- a/Sources/SpecialErrorPages/SSLErrorType.swift +++ b/Sources/SpecialErrorPages/SSLErrorType.swift @@ -17,28 +17,27 @@ // import Foundation +import WebKit -public enum SSLErrorType: String { +public let SSLErrorCodeKey = "_kCFStreamErrorCodeKey" + +public enum SSLErrorType: String, Encodable { case expired - case wrongHost case selfSigned + case wrongHost case invalid - public static func forErrorCode(_ errorCode: Int) -> Self { - switch Int32(errorCode) { - case errSSLCertExpired: - return .expired - case errSSLHostNameMismatch: - return .wrongHost - case errSSLXCertChainInvalid: - return .selfSigned - default: - return .invalid + init(errorCode: Int32) { + self = switch errorCode { + case errSSLCertExpired: .expired + case errSSLXCertChainInvalid: .selfSigned + case errSSLHostNameMismatch: .wrongHost + default: .invalid } } - public var rawParameter: String { + public var pixelParameter: String { switch self { case .expired: return "expired" case .wrongHost: return "wrong_host" @@ -48,3 +47,16 @@ public enum SSLErrorType: String { } } + +extension WKError { + public var sslErrorType: SSLErrorType? { + _nsError.sslErrorType + } +} +extension NSError { + public var sslErrorType: SSLErrorType? { + guard let errorCode = self.userInfo[SSLErrorCodeKey] as? Int32 else { return nil } + let sslErrorType = SSLErrorType(errorCode: errorCode) + return sslErrorType + } +} diff --git a/Sources/SpecialErrorPages/SpecialErrorData.swift b/Sources/SpecialErrorPages/SpecialErrorData.swift index 7ceb0baef..048077847 100644 --- a/Sources/SpecialErrorPages/SpecialErrorData.swift +++ b/Sources/SpecialErrorPages/SpecialErrorData.swift @@ -17,24 +17,61 @@ // import Foundation +import MaliciousSiteProtection public enum SpecialErrorKind: String, Encodable { case ssl case phishing + // case malware } -public struct SpecialErrorData: Encodable, Equatable { +public enum SpecialErrorData: Encodable, Equatable { - var kind: SpecialErrorKind - var errorType: String? - var domain: String? - var eTldPlus1: String? + enum CodingKeys: CodingKey { + case kind + case errorType + case domain + case eTldPlus1 + case url + } + + case ssl(type: SSLErrorType, domain: String, eTldPlus1: String?) + case maliciousSite(kind: MaliciousSiteProtection.ThreatKind, url: URL) + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + switch self { + case .ssl(type: let type, domain: let domain, eTldPlus1: let eTldPlus1): + try container.encode(SpecialErrorKind.ssl, forKey: .kind) + try container.encode(type, forKey: .errorType) + try container.encode(domain, forKey: .domain) - public init(kind: SpecialErrorKind, errorType: String? = nil, domain: String? = nil, eTldPlus1: String? = nil) { - self.kind = kind - self.errorType = errorType - self.domain = domain - self.eTldPlus1 = eTldPlus1 + switch type { + case .expired, .selfSigned, .invalid: break + case .wrongHost: + guard let eTldPlus1 else { + assertionFailure("expected eTldPlus1 != nil when kind is .wrongHost") + break + } + try container.encode(eTldPlus1, forKey: .eTldPlus1) + } + + case .maliciousSite(kind: let kind, url: let url): + // https://app.asana.com/0/1206594217596623/1208824527069247/f + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(kind.errorPageKind, forKey: .kind) + try container.encode(url, forKey: .url) + } } } + +public extension MaliciousSiteProtection.ThreatKind { + var errorPageKind: SpecialErrorKind { + switch self { + // case .malware: .malware + case .phishing: .phishing + } + } +} diff --git a/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift b/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift index 71cad64e8..f131358ef 100644 --- a/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift +++ b/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift @@ -23,11 +23,11 @@ import Common public protocol SpecialErrorPageUserScriptDelegate: AnyObject { - var errorData: SpecialErrorData? { get } + @MainActor var errorData: SpecialErrorData? { get } - func leaveSite() - func visitSite() - func advancedInfoPresented() + @MainActor func leaveSiteAction() + @MainActor func visitSiteAction() + @MainActor func advancedInfoPresented() } @@ -105,13 +105,13 @@ public final class SpecialErrorPageUserScript: NSObject, Subfeature { @MainActor func handleLeaveSiteAction(params: Any, message: UserScriptMessage) -> Encodable? { - delegate?.leaveSite() + delegate?.leaveSiteAction() return nil } @MainActor func handleVisitSiteAction(params: Any, message: UserScriptMessage) -> Encodable? { - delegate?.visitSite() + delegate?.visitSiteAction() return nil } diff --git a/Sources/Subscription/API/Model/Entitlement.swift b/Sources/Subscription/API/Model/Entitlement.swift index c90e7342c..1d8eb645a 100644 --- a/Sources/Subscription/API/Model/Entitlement.swift +++ b/Sources/Subscription/API/Model/Entitlement.swift @@ -25,6 +25,7 @@ public struct Entitlement: Codable, Equatable { case networkProtection = "Network Protection" case dataBrokerProtection = "Data Broker Protection" case identityTheftRestoration = "Identity Theft Restoration" + case identityTheftRestorationGlobal = "Global Identity Theft Restoration" case unknown public init(from decoder: Decoder) throws { diff --git a/Sources/Subscription/API/SubscriptionEndpointService.swift b/Sources/Subscription/API/SubscriptionEndpointService.swift index 7898bbddb..552c28d4a 100644 --- a/Sources/Subscription/API/SubscriptionEndpointService.swift +++ b/Sources/Subscription/API/SubscriptionEndpointService.swift @@ -27,6 +27,10 @@ public struct GetProductsItem: Decodable { public let currency: String } +public struct GetSubscriptionFeaturesResponse: Decodable { + public let features: [Entitlement.ProductName] +} + public struct GetCustomerPortalURLResponse: Decodable { public let customerPortalUrl: String } @@ -47,6 +51,7 @@ public protocol SubscriptionEndpointService { func getSubscription(accessToken: String, cachePolicy: APICachePolicy) async -> Result func signOut() func getProducts() async -> Result<[GetProductsItem], APIServiceError> + func getSubscriptionFeatures(for subscriptionID: String) async -> Result func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result func confirmPurchase(accessToken: String, signature: String) async -> Result } @@ -137,6 +142,12 @@ public struct DefaultSubscriptionEndpointService: SubscriptionEndpointService { // MARK: - + public func getSubscriptionFeatures(for subscriptionID: String) async -> Result { + await apiService.executeAPICall(method: "GET", endpoint: "products/\(subscriptionID)/features", headers: nil, body: nil) + } + + // MARK: - + public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result { var headers = apiService.makeAuthorizationHeader(for: accessToken) headers["externalAccountId"] = externalID diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift b/Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift similarity index 53% rename from Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift rename to Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift index 4a56474e0..f2541131e 100644 --- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift +++ b/Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift @@ -1,5 +1,5 @@ // -// PhishingDetectorMock.swift +// FeatureFlaggerMapping.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,22 +17,17 @@ // import Foundation -import PhishingDetection -public class MockPhishingDetector: PhishingDetecting { - private var mockClient: PhishingDetectionClientProtocol - public var didCallIsMalicious: Bool = false +open class FeatureFlaggerMapping { + public typealias Mapping = (_ feature: Feature) -> Bool - init() { - self.mockClient = MockPhishingDetectionClient() - } + private let isFeatureEnabledMapping: Mapping - public func getMatches(hashPrefix: String) async -> Set { - let matches = await mockClient.getMatches(hashPrefix: hashPrefix) - return Set(matches) + public init(mapping: @escaping Mapping) { + isFeatureEnabledMapping = mapping } - public func isMalicious(url: URL) async -> Bool { - return url.absoluteString.contains("malicious") + public func isFeatureOn(_ feature: Feature) -> Bool { + return isFeatureEnabledMapping(feature) } } diff --git a/Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift b/Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift similarity index 54% rename from Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift rename to Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift index 540c0c2e8..104d20066 100644 --- a/Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift +++ b/Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift @@ -1,5 +1,5 @@ // -// SubscriptionEnvironmentNames.swift +// SubscriptionFeatureFlags.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -18,18 +18,21 @@ import Foundation -public enum SubscriptionFeatureName: String, CaseIterable { - case privateBrowsing = "private-browsing" - case privateSearch = "private-search" - case emailProtection = "email-protection" - case appTrackingProtection = "app-tracking-protection" - case vpn = "vpn" - case personalInformationRemoval = "personal-information-removal" - case identityTheftRestoration = "identity-theft-restoration" +public enum SubscriptionFeatureFlags { + case isLaunchedROW + case isLaunchedROWOverride + case usePrivacyProUSARegionOverride + case usePrivacyProROWRegionOverride } -public enum SubscriptionPlatformName: String { - case ios - case macos - case stripe +public extension SubscriptionFeatureFlags { + + var defaultState: Bool { + switch self { + case .isLaunchedROW, .isLaunchedROWOverride: + return true + case .usePrivacyProUSARegionOverride, .usePrivacyProROWRegionOverride: + return false + } + } } diff --git a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift index ca0347384..5544cf680 100644 --- a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift +++ b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift @@ -19,21 +19,34 @@ import Foundation public struct SubscriptionOptions: Encodable, Equatable { - let platform: String + let platform: SubscriptionPlatformName let options: [SubscriptionOption] let features: [SubscriptionFeature] + public static var empty: SubscriptionOptions { - let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) } + let features = [SubscriptionFeature(name: .networkProtection), + SubscriptionFeature(name: .dataBrokerProtection), + SubscriptionFeature(name: .identityTheftRestoration)] let platform: SubscriptionPlatformName #if os(iOS) platform = .ios #else platform = .macos #endif - return SubscriptionOptions(platform: platform.rawValue, options: [], features: features) + return SubscriptionOptions(platform: platform, options: [], features: features) + } + + public func withoutPurchaseOptions() -> Self { + SubscriptionOptions(platform: platform, options: [], features: features) } } +public enum SubscriptionPlatformName: String, Encodable { + case ios + case macos + case stripe +} + public struct SubscriptionOption: Encodable, Equatable { let id: String let cost: SubscriptionOptionCost @@ -45,5 +58,5 @@ struct SubscriptionOptionCost: Encodable, Equatable { } public struct SubscriptionFeature: Encodable, Equatable { - let name: String + let name: Entitlement.ProductName } diff --git a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift index 3912d96a2..43e0448e7 100644 --- a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift +++ b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift @@ -71,9 +71,11 @@ public final class DefaultStripePurchaseFlow: StripePurchaseFlow { cost: cost) } - let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) } + let features = [SubscriptionFeature(name: .networkProtection), + SubscriptionFeature(name: .dataBrokerProtection), + SubscriptionFeature(name: .identityTheftRestoration)] - return .success(SubscriptionOptions(platform: SubscriptionPlatformName.stripe.rawValue, + return .success(SubscriptionOptions(platform: SubscriptionPlatformName.stripe, options: options, features: features)) } diff --git a/Sources/Subscription/Managers/StorePurchaseManager.swift b/Sources/Subscription/Managers/StorePurchaseManager.swift index 24ebc3d31..47791eb22 100644 --- a/Sources/Subscription/Managers/StorePurchaseManager.swift +++ b/Sources/Subscription/Managers/StorePurchaseManager.swift @@ -41,6 +41,7 @@ public protocol StorePurchaseManager { var purchasedProductIDs: [String] { get } var purchaseQueue: [String] { get } var areProductsAvailable: Bool { get } + var currentStorefrontRegion: SubscriptionRegion { get } @MainActor func syncAppleIDAccount() async throws @MainActor func updateAvailableProducts() async @@ -56,21 +57,24 @@ public protocol StorePurchaseManager { @available(macOS 12.0, iOS 15.0, *) public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseManager { - let productIdentifiers = ["ios.subscription.1month", "ios.subscription.1year", - "subscription.1month", "subscription.1year", - "review.subscription.1month", "review.subscription.1year", - "tf.sandbox.subscription.1month", "tf.sandbox.subscription.1year", - "ddg.privacy.pro.monthly.renews.us", "ddg.privacy.pro.yearly.renews.us"] + private let storeSubscriptionConfiguration: StoreSubscriptionConfiguration + private let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache + private let subscriptionFeatureFlagger: FeatureFlaggerMapping? @Published public private(set) var availableProducts: [Product] = [] @Published public private(set) var purchasedProductIDs: [String] = [] @Published public private(set) var purchaseQueue: [String] = [] public var areProductsAvailable: Bool { !availableProducts.isEmpty } + public private(set) var currentStorefrontRegion: SubscriptionRegion = .usa private var transactionUpdates: Task? private var storefrontChanges: Task? - public init() { + public init(subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache, + subscriptionFeatureFlagger: FeatureFlaggerMapping? = nil) { + self.storeSubscriptionConfiguration = DefaultStoreSubscriptionConfiguration() + self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache + self.subscriptionFeatureFlagger = subscriptionFeatureFlagger transactionUpdates = observeTransactionUpdates() storefrontChanges = observeStorefrontChanges() } @@ -109,17 +113,29 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM return nil } - let options = [SubscriptionOption(id: monthly.id, cost: .init(displayPrice: monthly.displayPrice, recurrence: "monthly")), - SubscriptionOption(id: yearly.id, cost: .init(displayPrice: yearly.displayPrice, recurrence: "yearly"))] - let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) } - let platform: SubscriptionPlatformName - + let platform: SubscriptionPlatformName = { #if os(iOS) - platform = .ios + .ios #else - platform = .macos + .macos #endif - return SubscriptionOptions(platform: platform.rawValue, + }() + + let options = [SubscriptionOption(id: monthly.id, + cost: .init(displayPrice: monthly.displayPrice, recurrence: "monthly")), + SubscriptionOption(id: yearly.id, + cost: .init(displayPrice: yearly.displayPrice, recurrence: "yearly"))] + + let features: [SubscriptionFeature] + + if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { + features = await subscriptionFeatureMappingCache.subscriptionFeatures(for: monthly.id).compactMap { SubscriptionFeature(name: $0) } + } else { + let allFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] + features = allFeatures.compactMap { SubscriptionFeature(name: $0) } + } + + return SubscriptionOptions(platform: platform, options: options, features: features) } @@ -129,11 +145,36 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts") do { - let availableProducts = try await Product.products(for: productIdentifiers) - Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts fetched \(availableProducts.count) products") + let storefrontCountryCode: String? + let storefrontRegion: SubscriptionRegion + + if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) { + if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProUSARegionOverride) { + storefrontCountryCode = "USA" + } else if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProROWRegionOverride) { + storefrontCountryCode = "POL" + } else { + storefrontCountryCode = await Storefront.current?.countryCode + } + + storefrontRegion = SubscriptionRegion.matchingRegion(for: storefrontCountryCode ?? "USA") ?? .usa // Fallback to USA + } else { + storefrontCountryCode = "USA" + storefrontRegion = .usa + } + + self.currentStorefrontRegion = storefrontRegion + let applicableProductIdentifiers = storeSubscriptionConfiguration.subscriptionIdentifiers(for: storefrontRegion) + let availableProducts = try await Product.products(for: applicableProductIdentifiers) + Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts fetched \(availableProducts.count) products for \(storefrontCountryCode ?? "", privacy: .public)") if self.availableProducts != availableProducts { self.availableProducts = availableProducts + + // Update cached subscription features mapping + for id in availableProducts.compactMap({ $0.id }) { + _ = await subscriptionFeatureMappingCache.subscriptionFeatures(for: id) + } } } catch { Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public)") @@ -295,3 +336,36 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM } } } + +public extension UserDefaults { + + enum Constants { + static let storefrontRegionOverrideKey = "Subscription.debug.storefrontRegionOverride" + static let usaValue = "usa" + static let rowValue = "row" + } + + dynamic var storefrontRegionOverride: SubscriptionRegion? { + get { + switch string(forKey: Constants.storefrontRegionOverrideKey) { + case "usa": + return .usa + case "row": + return .restOfWorld + default: + return nil + } + } + + set { + switch newValue { + case .usa: + set(Constants.usaValue, forKey: Constants.storefrontRegionOverrideKey) + case .restOfWorld: + set(Constants.rowValue, forKey: Constants.storefrontRegionOverrideKey) + default: + removeObject(forKey: Constants.storefrontRegionOverrideKey) + } + } + } +} diff --git a/Sources/Subscription/Managers/SubscriptionManager.swift b/Sources/Subscription/Managers/SubscriptionManager.swift index ef8abfe85..cac861106 100644 --- a/Sources/Subscription/Managers/SubscriptionManager.swift +++ b/Sources/Subscription/Managers/SubscriptionManager.swift @@ -24,6 +24,7 @@ public protocol SubscriptionManager { var accountManager: AccountManager { get } var subscriptionEndpointService: SubscriptionEndpointService { get } var authEndpointService: AuthEndpointService { get } + var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache { get } // Environment static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment? @@ -35,6 +36,7 @@ public protocol SubscriptionManager { func loadInitialData() func refreshCachedSubscriptionAndEntitlements(completion: @escaping (_ isSubscriptionActive: Bool) -> Void) func url(for type: SubscriptionURL) -> URL + func currentSubscriptionFeatures() async -> [Entitlement.ProductName] } /// Single entry point for everything related to Subscription. This manager is disposable, every time something related to the environment changes this need to be recreated. @@ -43,19 +45,26 @@ public final class DefaultSubscriptionManager: SubscriptionManager { public let accountManager: AccountManager public let subscriptionEndpointService: SubscriptionEndpointService public let authEndpointService: AuthEndpointService + public let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache public let currentEnvironment: SubscriptionEnvironment - public private(set) var canPurchase: Bool = false + + private let subscriptionFeatureFlagger: FeatureFlaggerMapping public init(storePurchaseManager: StorePurchaseManager? = nil, accountManager: AccountManager, subscriptionEndpointService: SubscriptionEndpointService, authEndpointService: AuthEndpointService, - subscriptionEnvironment: SubscriptionEnvironment) { + subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache, + subscriptionEnvironment: SubscriptionEnvironment, + subscriptionFeatureFlagger: FeatureFlaggerMapping) { self._storePurchaseManager = storePurchaseManager self.accountManager = accountManager self.subscriptionEndpointService = subscriptionEndpointService self.authEndpointService = authEndpointService + self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache self.currentEnvironment = subscriptionEnvironment + self.subscriptionFeatureFlagger = subscriptionFeatureFlagger + switch currentEnvironment.purchasePlatform { case .appStore: if #available(macOS 12.0, iOS 15.0, *) { @@ -68,6 +77,21 @@ public final class DefaultSubscriptionManager: SubscriptionManager { } } + public var canPurchase: Bool { + guard let storePurchaseManager = _storePurchaseManager else { return false } + + switch storePurchaseManager.currentStorefrontRegion { + case .usa: + return storePurchaseManager.areProductsAvailable + case .restOfWorld: + if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { + return storePurchaseManager.areProductsAvailable + } else { + return false + } + } + } + @available(macOS 12.0, iOS 15.0, *) public func storePurchaseManager() -> StorePurchaseManager { return _storePurchaseManager! @@ -98,7 +122,6 @@ public final class DefaultSubscriptionManager: SubscriptionManager { @available(macOS 12.0, iOS 15.0, *) private func setupForAppStore() { Task { await storePurchaseManager().updateAvailableProducts() - canPurchase = storePurchaseManager().areProductsAvailable } } @@ -147,4 +170,21 @@ public final class DefaultSubscriptionManager: SubscriptionManager { public func url(for type: SubscriptionURL) -> URL { type.subscriptionURL(environment: currentEnvironment.serviceEnvironment) } + + // MARK: - Current subscription's features + + public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] { + guard let token = accountManager.accessToken else { return [] } + + if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) { + switch await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .returnCacheDataElseLoad) { + case .success(let subscription): + return await subscriptionFeatureMappingCache.subscriptionFeatures(for: subscription.productId) + case .failure: + return [] + } + } else { + return [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] + } + } } diff --git a/Sources/Subscription/NSNotificationName+Subscription.swift b/Sources/Subscription/NSNotificationName+Subscription.swift index 1079c6231..b674aaff7 100644 --- a/Sources/Subscription/NSNotificationName+Subscription.swift +++ b/Sources/Subscription/NSNotificationName+Subscription.swift @@ -20,10 +20,6 @@ import Foundation public extension NSNotification.Name { - static let openPrivateBrowsing = Notification.Name("com.duckduckgo.subscription.open.private-browsing") - static let openPrivateSearch = Notification.Name("com.duckduckgo.subscription.open.private-search") - static let openEmailProtection = Notification.Name("com.duckduckgo.subscription.open.email-protection") - static let openAppTrackingProtection = Notification.Name("com.duckduckgo.subscription.open.app-tracking-protection") static let openVPN = Notification.Name("com.duckduckgo.subscription.open.vpn") static let openPersonalInformationRemoval = Notification.Name("com.duckduckgo.subscription.open.personal-information-removal") static let openIdentityTheftRestoration = Notification.Name("com.duckduckgo.subscription.open.identity-theft-restoration") diff --git a/Sources/Subscription/StoreSubscriptionConfiguration.swift b/Sources/Subscription/StoreSubscriptionConfiguration.swift new file mode 100644 index 000000000..8dacf7db4 --- /dev/null +++ b/Sources/Subscription/StoreSubscriptionConfiguration.swift @@ -0,0 +1,138 @@ +// +// StoreSubscriptionConfiguration.swift +// +// Copyright © 2023 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Combine + +protocol StoreSubscriptionConfiguration { + var allSubscriptionIdentifiers: [String] { get } + func subscriptionIdentifiers(for country: String) -> [String] + func subscriptionIdentifiers(for region: SubscriptionRegion) -> [String] +} + +final class DefaultStoreSubscriptionConfiguration: StoreSubscriptionConfiguration { + + private let subscriptions: [StoreSubscriptionDefinition] + + convenience init() { + self.init(subscriptionDefinitions: [ + // Production shared for iOS and macOS + .init(name: "DuckDuckGo Private Browser", + appIdentifier: "com.duckduckgo.mobile.ios", + environment: .production, + identifiersByRegion: [.usa: ["ddg.privacy.pro.monthly.renews.us", + "ddg.privacy.pro.yearly.renews.us"], + .restOfWorld: ["ddg.privacy.pro.monthly.renews.row", + "ddg.privacy.pro.yearly.renews.row"]]), + // iOS debug Alpha build + .init(name: "DuckDuckGo Alpha", + appIdentifier: "com.duckduckgo.mobile.ios.alpha", + environment: .staging, + identifiersByRegion: [.usa: ["ios.subscription.1month", + "ios.subscription.1year"], + .restOfWorld: ["ios.subscription.1month.row", + "ios.subscription.1year.row"]]), + // macOS debug build + .init(name: "IAP debug - DDG for macOS", + appIdentifier: "com.duckduckgo.macos.browser.debug", + environment: .staging, + identifiersByRegion: [.usa: ["subscription.1month", + "subscription.1year"], + .restOfWorld: ["subscription.1month.row", + "subscription.1year.row"]]), + // macOS review build + .init(name: "IAP review - DDG for macOS", + appIdentifier: "com.duckduckgo.macos.browser.review", + environment: .staging, + identifiersByRegion: [.usa: ["review.subscription.1month", + "review.subscription.1year"], + .restOfWorld: ["review.subscription.1month.row", + "review.subscription.1year.row"]]), + + // macOS TestFlight build + .init(name: "DuckDuckGo Sandbox Review", + appIdentifier: "com.duckduckgo.mobile.ios.review", + environment: .staging, + identifiersByRegion: [.usa: ["tf.sandbox.subscription.1month", + "tf.sandbox.subscription.1year"], + .restOfWorld: ["tf.sandbox.subscription.1month.row", + "tf.sandbox.subscription.1year.row"]]) + ]) + } + + init(subscriptionDefinitions: [StoreSubscriptionDefinition]) { + self.subscriptions = subscriptionDefinitions + } + + var allSubscriptionIdentifiers: [String] { + subscriptions.reduce([], { $0 + $1.allIdentifiers() }) + } + + func subscriptionIdentifiers(for country: String) -> [String] { + subscriptions.reduce([], { $0 + $1.identifiers(for: country) }) + } + + func subscriptionIdentifiers(for region: SubscriptionRegion) -> [String] { + subscriptions.reduce([], { $0 + $1.identifiers(for: region) }) + } +} + +struct StoreSubscriptionDefinition { + var name: String + var appIdentifier: String + var environment: SubscriptionEnvironment.ServiceEnvironment + var identifiersByRegion: [SubscriptionRegion: [String]] + + func allIdentifiers() -> [String] { + identifiersByRegion.values.flatMap { $0 } + } + + func identifiers(for country: String) -> [String] { + identifiersByRegion.filter { region, _ in region.contains(country) }.flatMap { _, identifiers in identifiers } + } + + func identifiers(for region: SubscriptionRegion) -> [String] { + identifiersByRegion[region] ?? [] + } +} + +public enum SubscriptionRegion: CaseIterable { + case usa + case restOfWorld + + /// Country codes as used by StoreKit, in the ISO 3166-1 Alpha-3 country code representation + /// For .restOfWorld definiton see https://app.asana.com/0/1208524871249522/1208571752166956/f + var countryCodes: Set { + switch self { + case .usa: + return Set(["USA"]) + case .restOfWorld: + return Set(["CAN", "GBR", "AUT", "DEU", "NLD", "POL", "SWE", + "BEL", "BGR", "HRV ", "CYP", "CZE", "DNK", "EST", "FIN", "FRA", "GRC", "HUN", "IRL", "ITA", "LVA", "LTU", "LUX", "MLT", "PRT", + "ROU", "SVK", "SVN", "ESP"]) + } + } + + func contains(_ country: String) -> Bool { + countryCodes.contains(country.uppercased()) + } + + static func matchingRegion(for countryCode: String) -> Self? { + Self.allCases.first { $0.countryCodes.contains(countryCode) } + } +} diff --git a/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift b/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift index 74fec5568..3b561c637 100644 --- a/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift +++ b/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift @@ -25,4 +25,36 @@ public protocol HTTPCookieStore { func deleteCookie(_ cookie: HTTPCookie) async } -extension WKHTTPCookieStore: HTTPCookieStore {} +@MainActor +public struct WKHTTPCookieStoreWrapper: HTTPCookieStore { + + let store: WKHTTPCookieStore + + public init(store: WKHTTPCookieStore) { + self.store = store + } + + public func allCookies() async -> [HTTPCookie] { + return await withCheckedContinuation { continuation in + store.getAllCookies { cookies in + continuation.resume(returning: cookies) + } + } + } + + public func setCookie(_ cookie: HTTPCookie) async { + await withCheckedContinuation { continuation in + store.setCookie(cookie) { + continuation.resume() + } + } + } + + public func deleteCookie(_ cookie: HTTPCookie) async { + await withCheckedContinuation { continuation in + store.delete(cookie) { + continuation.resume() + } + } + } +} diff --git a/Sources/Subscription/SubscriptionFeatureMappingCache.swift b/Sources/Subscription/SubscriptionFeatureMappingCache.swift new file mode 100644 index 000000000..414fec1d7 --- /dev/null +++ b/Sources/Subscription/SubscriptionFeatureMappingCache.swift @@ -0,0 +1,136 @@ +// +// SubscriptionFeatureMappingCache.swift +// +// Copyright © 2023 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import os.log + +public protocol SubscriptionFeatureMappingCache { + func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] +} + +public final class DefaultSubscriptionFeatureMappingCache: SubscriptionFeatureMappingCache { + + private let subscriptionEndpointService: SubscriptionEndpointService + private let userDefaults: UserDefaults + + private var subscriptionFeatureMapping: SubscriptionFeatureMapping? + + public init(subscriptionEndpointService: SubscriptionEndpointService, userDefaults: UserDefaults) { + self.subscriptionEndpointService = subscriptionEndpointService + self.userDefaults = userDefaults + } + + public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] \(#function) \(subscriptionIdentifier)") + let features: [Entitlement.ProductName] + + if let subscriptionFeatures = currentSubscriptionFeatureMapping[subscriptionIdentifier] { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] - got cached features") + features = subscriptionFeatures + } else if let subscriptionFeatures = await fetchRemoteFeatures(for: subscriptionIdentifier) { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] - fetching features from BE API") + features = subscriptionFeatures + updateCachedFeatureMapping(with: subscriptionFeatures, for: subscriptionIdentifier) + } else { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] - Error: using fallback") + features = fallbackFeatures + } + + return features + } + + // MARK: - Current feature mapping + + private var currentSubscriptionFeatureMapping: SubscriptionFeatureMapping { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] - \(#function)") + let featureMapping: SubscriptionFeatureMapping + + if let cachedFeatureMapping { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- got cachedFeatureMapping") + featureMapping = cachedFeatureMapping + } else if let storedFeatureMapping { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- have to fetchStoredFeatureMapping") + featureMapping = storedFeatureMapping + updateCachedFeatureMapping(to: featureMapping) + } else { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- so creating a new one!") + featureMapping = SubscriptionFeatureMapping() + updateCachedFeatureMapping(to: featureMapping) + } + + return featureMapping + } + + // MARK: - Cached subscription feature mapping + + private var cachedFeatureMapping: SubscriptionFeatureMapping? + + private func updateCachedFeatureMapping(to featureMapping: SubscriptionFeatureMapping) { + cachedFeatureMapping = featureMapping + } + + private func updateCachedFeatureMapping(with features: [Entitlement.ProductName], for subscriptionIdentifier: String) { + var updatedFeatureMapping = cachedFeatureMapping ?? SubscriptionFeatureMapping() + updatedFeatureMapping[subscriptionIdentifier] = features + + self.cachedFeatureMapping = updatedFeatureMapping + self.storedFeatureMapping = updatedFeatureMapping + } + + // MARK: - Stored subscription feature mapping + + static private let subscriptionFeatureMappingKey = "com.duckduckgo.subscription.featuremapping" + + dynamic var storedFeatureMapping: SubscriptionFeatureMapping? { + get { + guard let data = userDefaults.data(forKey: Self.subscriptionFeatureMappingKey) else { return nil } + do { + return try JSONDecoder().decode(SubscriptionFeatureMapping?.self, from: data) + } catch { + assertionFailure("Errored while decoding feature mapping") + return nil + } + } + + set { + do { + let data = try JSONEncoder().encode(newValue) + userDefaults.set(data, forKey: Self.subscriptionFeatureMappingKey) + } catch { + assertionFailure("Errored while encoding feature mapping") + } + } + } + + // MARK: - Remote subscription feature mapping + + private func fetchRemoteFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName]? { + if case let .success(response) = await subscriptionEndpointService.getSubscriptionFeatures(for: subscriptionIdentifier) { + Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- Fetched features for `\(subscriptionIdentifier)`: \(response.features)") + return response.features + } + + return nil + } + + // MARK: - Fallback subscription feature mapping + + private let fallbackFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration] +} + +typealias SubscriptionFeatureMapping = [String: [Entitlement.ProductName]] diff --git a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift index 132685772..b17d585d0 100644 --- a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift +++ b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift @@ -22,6 +22,7 @@ import Subscription public final class SubscriptionEndpointServiceMock: SubscriptionEndpointService { public var getSubscriptionResult: Result? public var getProductsResult: Result<[GetProductsItem], APIServiceError>? + public var getSubscriptionFeaturesResult: Result? public var getCustomerPortalURLResult: Result? public var confirmPurchaseResult: Result? @@ -55,6 +56,10 @@ public final class SubscriptionEndpointServiceMock: SubscriptionEndpointService getProductsResult! } + public func getSubscriptionFeatures(for subscriptionID: String) async -> Result { + getSubscriptionFeaturesResult! + } + public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result { getCustomerPortalURLResult! } diff --git a/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift index 6f654ba72..e326fd720 100644 --- a/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift +++ b/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift @@ -23,6 +23,8 @@ public final class StorePurchaseManagerMock: StorePurchaseManager { public var purchasedProductIDs: [String] = [] public var purchaseQueue: [String] = [] public var areProductsAvailable: Bool = false + public var currentStorefrontRegion: SubscriptionRegion = .usa + public var subscriptionOptionsResult: SubscriptionOptions? public var syncAppleIDAccountResultError: Error? diff --git a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift index 46fdc77e0..3217eedbb 100644 --- a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift +++ b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift @@ -23,6 +23,7 @@ public final class SubscriptionManagerMock: SubscriptionManager { public var accountManager: AccountManager public var subscriptionEndpointService: SubscriptionEndpointService public var authEndpointService: AuthEndpointService + public var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache public static var storedEnvironment: SubscriptionEnvironment? public static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment? { @@ -52,18 +53,24 @@ public final class SubscriptionManagerMock: SubscriptionManager { type.subscriptionURL(environment: currentEnvironment.serviceEnvironment) } + public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] { + return [] + } + public init(accountManager: AccountManager, subscriptionEndpointService: SubscriptionEndpointService, authEndpointService: AuthEndpointService, storePurchaseManager: StorePurchaseManager, currentEnvironment: SubscriptionEnvironment, - canPurchase: Bool) { + canPurchase: Bool, + subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache) { self.accountManager = accountManager self.subscriptionEndpointService = subscriptionEndpointService self.authEndpointService = authEndpointService self.internalStorePurchaseManager = storePurchaseManager self.currentEnvironment = currentEnvironment self.canPurchase = canPurchase + self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache } // MARK: - diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift index a9166689a..b2a5b8133 100644 --- a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift +++ b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift @@ -29,13 +29,15 @@ public final class SubscriptionCookieManagerMock: SubscriptionCookieManaging { let subscriptionService = DefaultSubscriptionEndpointService(currentServiceEnvironment: .production) let authService = DefaultAuthEndpointService(currentServiceEnvironment: .production) let storePurchaseManager = StorePurchaseManagerMock() + let subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() let subscriptionManager = SubscriptionManagerMock(accountManager: accountManager, subscriptionEndpointService: subscriptionService, authEndpointService: authService, storePurchaseManager: storePurchaseManager, currentEnvironment: SubscriptionEnvironment(serviceEnvironment: .production, purchasePlatform: .appStore), - canPurchase: true) + canPurchase: true, + subscriptionFeatureMappingCache: subscriptionFeatureMappingCache) self.init(subscriptionManager: subscriptionManager, currentCookieStore: { return nil }, diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift new file mode 100644 index 000000000..c4612a9bd --- /dev/null +++ b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift @@ -0,0 +1,31 @@ +// +// SubscriptionFeatureMappingCacheMock.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import Subscription + +public final class SubscriptionFeatureMappingCacheMock: SubscriptionFeatureMappingCache { + + public var mapping: [String: [Entitlement.ProductName]] = [:] + + public init() { } + + public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] { + return mapping[subscriptionIdentifier] ?? [] + } +} diff --git a/Sources/TestUtils/MockAPIService.swift b/Sources/TestUtils/MockAPIService.swift index be4bf47a2..f4d35b4b6 100644 --- a/Sources/TestUtils/MockAPIService.swift +++ b/Sources/TestUtils/MockAPIService.swift @@ -19,12 +19,16 @@ import Foundation import Networking -public struct MockAPIService: APIService { +public class MockAPIService: APIService { - public var apiResponse: Result + public var requestHandler: ((APIRequestV2) -> Result)! - public func fetch(request: Networking.APIRequestV2) async throws -> APIResponseV2 { - switch apiResponse { + public init(requestHandler: ((APIRequestV2) -> Result)? = nil) { + self.requestHandler = requestHandler + } + + public func fetch(request: APIRequestV2) async throws -> APIResponseV2 { + switch requestHandler!(request) { case .success(let result): return result case .failure(let error): diff --git a/Sources/UserScript/UserScript.swift b/Sources/UserScript/UserScript.swift index 3b35ddc42..728b3b36a 100644 --- a/Sources/UserScript/UserScript.swift +++ b/Sources/UserScript/UserScript.swift @@ -107,7 +107,7 @@ extension UserScript { } public func makeWKUserScript() async -> WKUserScriptBox { - let source = (try? await Task.detached { [source] in Self.prepareScriptSource(from: source) }.result.get())! + let source = await Task.detached { [source] in Self.prepareScriptSource(from: source) }.result.get() return await Self.makeWKUserScript(from: source, injectionTime: injectionTime, forMainFrameOnly: forMainFrameOnly, diff --git a/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift b/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift index 9a3d7610c..29adc74fb 100644 --- a/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift +++ b/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift @@ -59,7 +59,7 @@ class PrivacyConfigurationMock: PrivacyConfiguration { return enabledSubfeaturesForVersions[subfeature.rawValue]?.contains(versionProvider.appVersion() ?? "") ?? false } - func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { + func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { if isSubfeatureEnabled(subfeature, versionProvider: versionProvider, randomizer: randomizer) { return .enabled } @@ -98,6 +98,18 @@ class PrivacyConfigurationMock: PrivacyConfiguration { return userUnprotected.contains(domain ?? "") } + func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { + return .enabled + } + + func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + + func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + } class PrivacyConfigurationManagerMock: PrivacyConfigurationManaging { diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift index f6082a7b2..455734d9b 100644 --- a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift +++ b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift @@ -24,8 +24,9 @@ import TrackerRadarKit import BrowserServicesKit import WebKit import XCTest +import Common -final class CountedFulfillmentTestExpectation: XCTestExpectation { +final class CountedFulfillmentTestExpectation: XCTestExpectation, @unchecked Sendable { private(set) var currentFulfillmentCount: Int = 0 override func fulfill() { @@ -57,12 +58,24 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { expStore.fulfill() } + let lookupAndFetchExp = self.expectation(description: "LRC should be missing") + let errorHandler = EventMapping { event, _, params, _ in + if case .contentBlockingLRCMissing = event { + lookupAndFetchExp.fulfill() + } else if case .contentBlockingCompilationTime = event { + XCTAssertNotNil(params?["compilationTime"]) + } else { + XCTFail("Unexpected event: \(event)") + } + } + let cbrm = ContentBlockerRulesManager(rulesSource: mockRulesSource, exceptionsSource: mockExceptionsSource, lastCompiledRulesStore: mockLastCompiledRulesStore, - updateListener: rulesUpdateListener) + updateListener: rulesUpdateListener, + errorReporting: errorHandler) - wait(for: [exp, expStore], timeout: 15.0) + wait(for: [exp, expStore, lookupAndFetchExp], timeout: 15.0) XCTAssertNotNil(mockLastCompiledRulesStore.rules) XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, mockRulesSource.trackerData?.etag) @@ -93,6 +106,8 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { XCTFail("Should use rules cached by WebKit") } + let lookupAndFetchExp = self.expectation(description: "Should not fetch LRC") + // simulate the rules have been compiled in the past so the WKContentRuleListStore contains it _ = ContentBlockerRulesManager(rulesSource: mockRulesSource, exceptionsSource: mockExceptionsSource, @@ -103,15 +118,28 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { rulesUpdateListener.onRulesUpdated = { rules in exp.fulfill() if exp.currentFulfillmentCount == 1 { // finished compilation after first installation + let errorHandler = EventMapping { event, _, params, _ in + if case .contentBlockingFetchLRCSucceeded = event { + XCTFail("Should not fetch LRC") + } else if case .contentBlockingLookupRulesSucceeded = event { + lookupAndFetchExp.fulfill() + } else if case .contentBlockingCompilationTime = event { + XCTAssertNotNil(params?["compilationTime"]) + } else { + XCTFail("Unexpected event: \(event)") + } + } + _ = ContentBlockerRulesManager(rulesSource: mockRulesSource, exceptionsSource: mockExceptionsSource, lastCompiledRulesStore: mockLastCompiledRulesStore, - updateListener: self.rulesUpdateListener) + updateListener: self.rulesUpdateListener, + errorReporting: errorHandler) } assertRules(rules) } - wait(for: [exp], timeout: 15.0) + wait(for: [exp, lookupAndFetchExp], timeout: 15.0) func assertRules(_ rules: [ContentBlockerRulesManager.Rules]) { guard let rules = rules.first else { XCTFail("Couldn't get rules"); return } @@ -178,12 +206,27 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { XCTAssertEqual(newListName, rules.name) } + let lookupAndFetchExp = self.expectation(description: "Should not fetch LRC") + + let errorHandler = EventMapping { event, _, params, _ in + if case .contentBlockingFetchLRCSucceeded = event { + XCTFail("Should not fetch LRC") + } else if case .contentBlockingNoMatchInLRC = event { + lookupAndFetchExp.fulfill() + } else if case .contentBlockingCompilationTime = event { + XCTAssertNotNil(params?["compilationTime"]) + } else { + XCTFail("Unexpected event: \(event)") + } + } + _ = ContentBlockerRulesManager(rulesSource: mockRulesSource, exceptionsSource: mockExceptionsSource, lastCompiledRulesStore: mockLastCompiledRulesStore, - updateListener: rulesUpdateListener) + updateListener: rulesUpdateListener, + errorReporting: errorHandler) - wait(for: [expCacheLookup, expNext], timeout: 15.0) + wait(for: [expCacheLookup, expNext, lookupAndFetchExp], timeout: 15.0) } func testInitialCompilation_WhenThereAreChangesToTDS_ShouldBuildRulesUsingLastCompiledRulesAndScheduleRecompilationWithNewSource() { @@ -220,14 +263,26 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { exceptionsSource: mockExceptionsSource, updateListener: rulesUpdateListener) + let lookupAndFetchExp = self.expectation(description: "Fetch LRC succeeded") let expOld = CountedFulfillmentTestExpectation(description: "Old Rules Compiled") rulesUpdateListener.onRulesUpdated = { _ in expOld.fulfill() + let errorHandler = EventMapping { event, _, params, _ in + if case .contentBlockingFetchLRCSucceeded = event { + lookupAndFetchExp.fulfill() + } else if case .contentBlockingCompilationTime = event { + XCTAssertNotNil(params?["compilationTime"]) + } else { + XCTFail("Unexpected event: \(event)") + } + } + _ = ContentBlockerRulesManager(rulesSource: mockUpdatedRulesSource, exceptionsSource: mockExceptionsSource, lastCompiledRulesStore: mockLastCompiledRulesStore, - updateListener: self.rulesUpdateListener) + updateListener: self.rulesUpdateListener, + errorReporting: errorHandler) } wait(for: [expOld], timeout: 15.0) @@ -237,27 +292,27 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase { expLastCompiledFetched.fulfill() } - let expRecompiled = CountedFulfillmentTestExpectation(description: "New Rules Compiled") - rulesUpdateListener.onRulesUpdated = { _ in - expRecompiled.fulfill() - } + let expRecompiled = CountedFulfillmentTestExpectation(description: "New Rules Compiled") + rulesUpdateListener.onRulesUpdated = { _ in + expRecompiled.fulfill() + + if expRecompiled.currentFulfillmentCount == 1 { // finished compilation after cold start (using last compiled rules) + mockLastCompiledRulesStore.onRulesGet = {} + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, oldEtag) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, oldIdentifier) + } else if expRecompiled.currentFulfillmentCount == 2 { // finished recompilation of rules due to changed tds + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, updatedEtag) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds) + XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, newIdentifier) + } + } - if expRecompiled.currentFulfillmentCount == 1 { // finished compilation after cold start (using last compiled rules) - - mockLastCompiledRulesStore.onRulesGet = {} - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, oldEtag) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, oldIdentifier) - } else if expRecompiled.currentFulfillmentCount == 2 { // finished recompilation of rules due to changed tds - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, updatedEtag) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds) - XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, newIdentifier) - } + wait(for: [expLastCompiledFetched, expRecompiled, lookupAndFetchExp], timeout: 15.0) - wait(for: [expLastCompiledFetched, expRecompiled], timeout: 15.0) - } + } struct MockLastCompiledRules: LastCompiledRules { diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift index 60df0f3a1..15321e1ba 100644 --- a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift +++ b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift @@ -203,6 +203,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { let errorExp = expectation(description: "No error reported") errorExp.isInverted = true + + let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed") let compilationTimeExp = expectation(description: "Compilation Time reported") let errorHandler = EventMapping { event, _, params, _ in if case .contentBlockingCompilationFailed(let listName, let component) = event { @@ -217,6 +219,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { } else if case .contentBlockingCompilationTime = event { XCTAssertNotNil(params?["compilationTime"]) compilationTimeExp.fulfill() + } else if case .contentBlockingLRCMissing = event { + lookupAndFetchExp.fulfill() } else { XCTFail("Unexpected event: \(event)") } @@ -227,7 +231,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { updateListener: rulesUpdateListener, errorReporting: errorHandler) - wait(for: [exp, errorExp, compilationTimeExp], timeout: 15.0) + wait(for: [exp, errorExp, compilationTimeExp, lookupAndFetchExp], timeout: 15.0) XCTAssertNotNil(cbrm.currentRules) XCTAssertEqual(cbrm.currentRules.first?.etag, mockRulesSource.trackerData?.etag) @@ -254,6 +258,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { } let errorExp = expectation(description: "Error reported") + let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed") + let errorHandler = EventMapping { event, _, params, _ in if case .contentBlockingCompilationFailed(let listName, let component) = event { XCTAssertEqual(listName, DefaultContentBlockerRulesListsSource.Constants.trackerDataSetRulesListName) @@ -266,6 +272,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { } else if case .contentBlockingCompilationTime = event { XCTAssertNotNil(params?["compilationTime"]) + } else if case .contentBlockingLRCMissing = event { + lookupAndFetchExp.fulfill() } else { XCTFail("Unexpected event: \(event)") } @@ -276,7 +284,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { updateListener: rulesUpdateListener, errorReporting: errorHandler) - wait(for: [exp, errorExp], timeout: 15.0) + wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0) XCTAssertNotNil(cbrm.currentRules) XCTAssertEqual(cbrm.currentRules.first?.etag, mockRulesSource.embeddedTrackerData.etag) @@ -539,6 +547,9 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { let errorExp = expectation(description: "Error reported") errorExp.expectedFulfillmentCount = 5 + + let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed") + var errorEvents = [ContentBlockerDebugEvents.Component]() let errorHandler = EventMapping { event, _, params, _ in if case .contentBlockingCompilationFailed(let listName, let component) = event { @@ -554,6 +565,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { } else if case .contentBlockingCompilationTime = event { XCTAssertNotNil(params?["compilationTime"]) errorExp.fulfill() + } else if case .contentBlockingLRCMissing = event { + lookupAndFetchExp.fulfill() } else { XCTFail("Unexpected event: \(event)") } @@ -564,7 +577,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { updateListener: rulesUpdateListener, errorReporting: errorHandler) - wait(for: [exp, errorExp], timeout: 15.0) + wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0) XCTAssertEqual(Set(errorEvents), Set([.tds, .tempUnprotected, @@ -619,6 +632,9 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { let errorExp = expectation(description: "Error reported") errorExp.expectedFulfillmentCount = 4 + + let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed") + var errorEvents = [ContentBlockerDebugEvents.Component]() let errorHandler = EventMapping { event, _, params, _ in if case .contentBlockingCompilationFailed(let listName, let component) = event { @@ -634,7 +650,10 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { } else if case .contentBlockingCompilationTime = event { XCTAssertNotNil(params?["compilationTime"]) errorExp.fulfill() - } else { + } else if case .contentBlockingLRCMissing = event { + lookupAndFetchExp.fulfill() + } else + { XCTFail("Unexpected event: \(event)") } } @@ -644,7 +663,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests { updateListener: rulesUpdateListener, errorReporting: errorHandler) - wait(for: [exp, errorExp], timeout: 15.0) + wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0) XCTAssertEqual(Set(errorEvents), Set([.tempUnprotected, .allowlist, diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift index e1cb475cd..767250149 100644 --- a/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift +++ b/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift @@ -318,18 +318,26 @@ class PrivacyConfigurationMock: PrivacyConfiguration { return .enabled } - func isSubfeatureEnabled( - _ subfeature: any PrivacySubfeature, - versionProvider: AppVersionProvider, - randomizer: (Range) -> Double - ) -> Bool { + func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool { true } - func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { return .enabled } + func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { + return .enabled + } + + func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + + func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + var identifier: String = "abcd" var version: String? = "123456789" var userUnprotectedDomains: [String] = [] diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift b/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift index 31370ce4a..bc979233e 100644 --- a/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift +++ b/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift @@ -225,3 +225,11 @@ final class WebKitTestHelper { } } } + +class MockExperimentCohortsManager: ExperimentCohortsManaging { + func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? { + return nil + } + + var experiments: Experiments? +} diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift index 3eb2db17e..6f5fffbd4 100644 --- a/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift @@ -52,15 +52,18 @@ final class CapturingFeatureFlagOverriding: FeatureFlagLocalOverriding { final class DefaultFeatureFlaggerTests: XCTestCase { var internalUserDeciderStore: MockInternalUserStoring! + var experimentManager: MockExperimentManager! var overrides: CapturingFeatureFlagOverriding! override func setUp() { super.setUp() internalUserDeciderStore = MockInternalUserStoring() + experimentManager = MockExperimentManager() } override func tearDown() { internalUserDeciderStore = nil + experimentManager = nil super.tearDown() } @@ -72,9 +75,9 @@ final class DefaultFeatureFlaggerTests: XCTestCase { func testWhenInternalOnly_returnsIsInternalUserValue() { let featureFlagger = createFeatureFlagger() internalUserDeciderStore.isInternalUser = false - XCTAssertFalse(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly)) + XCTAssertFalse(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly())) internalUserDeciderStore.isInternalUser = true - XCTAssertTrue(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly)) + XCTAssertTrue(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly())) } func testWhenRemoteDevelopment_isNOTInternalUser_returnsFalse() { @@ -141,6 +144,128 @@ final class DefaultFeatureFlaggerTests: XCTestCase { assertFeatureFlagger(with: embeddedData, willReturn: false, for: sourceProvider) } + // MARK: - Experiments + + func testWhenGetCohortIfEnabled_andSourceDisabled_returnsNil() { + let featureFlagger = createFeatureFlagger() + let flag = FakeExperimentFlag(source: .disabled) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func estWhenGetCohortIfEnabled_andSourceInternal_returnsPassedCohort() { + let featureFlagger = createFeatureFlagger() + let flag = FakeExperimentFlag(source: .internalOnly(AutofillCohort.blue)) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertEqual(cohort?.rawValue, AutofillCohort.blue.rawValue) + } + + func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssigned_returnsAssignedCohort() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = AutofillCohort.control.rawValue + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertEqual(cohort?.rawValue, AutofillCohort.control.rawValue) + } + + func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateFalse_and_cohortAssigned_returnsNil() { + internalUserDeciderStore.isInternalUser = false + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = AutofillCohort.control.rawValue + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssigned_andFeaturePassed_returnsNil() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = AutofillCohort.control.rawValue + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteDevelopment(.feature(.autofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortNotAssigned_returnsNil() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = nil + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssignedButNorMatchingEnum_returnsNil() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = "some" + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssigned_returnsAssignedCohort() { + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = AutofillCohort.control.rawValue + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertEqual(cohort?.rawValue, AutofillCohort.control.rawValue) + } + + func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssigned_andFeaturePassed_returnsNil() { + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = AutofillCohort.control.rawValue + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteReleasable(.feature(.autofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortNotAssigned_andFeaturePassed_returnsNil() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = nil + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + + func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssignedButNotMatchingEnum_returnsNil() { + internalUserDeciderStore.isInternalUser = true + let subfeature = AutofillSubfeature.credentialsAutofill + experimentManager.cohortToReturn = "some" + let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled")) + + let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill))) + let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData) + let cohort = featureFlagger.getCohortIfEnabled(for: flag) + XCTAssertNil(cohort) + } + // MARK: - Overrides func testWhenFeatureFlaggerIsInitializedWithLocalOverridesAndUserIsNotInternalThenAllFlagsAreCleared() throws { @@ -186,7 +311,7 @@ final class DefaultFeatureFlaggerTests: XCTestCase { localProtection: MockDomainsProtectionStore(), internalUserDecider: DefaultInternalUserDecider()) let internalUserDecider = DefaultInternalUserDecider(store: internalUserDeciderStore) - return DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: manager) + return DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: manager, experimentManager: experimentManager) } private func createFeatureFlaggerWithLocalOverrides(withMockedConfigData data: Data = DefaultFeatureFlaggerTests.embeddedConfig()) -> DefaultFeatureFlagger { @@ -203,6 +328,7 @@ final class DefaultFeatureFlaggerTests: XCTestCase { internalUserDecider: internalUserDecider, privacyConfigManager: manager, localOverrides: overrides, + experimentManager: nil, for: TestFeatureFlag.self ) } @@ -243,3 +369,25 @@ extension FeatureFlagSource: FeatureFlagDescribing { public var rawValue: String { "rawValue" } public var source: FeatureFlagSource { self } } + +class MockExperimentManager: ExperimentCohortsManaging { + var cohortToReturn: CohortID? + var experiments: BrowserServicesKit.Experiments? + + func resolveCohort(for experiment: BrowserServicesKit.ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? { + return cohortToReturn + } +} + +private struct FakeExperimentFlag: FeatureFlagExperimentDescribing { + typealias CohortType = AutofillCohort + + var rawValue: String = "fake-experiment" + + var source: FeatureFlagSource +} + +private enum AutofillCohort: String, FlagCohort { + case control + case blue +} diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift new file mode 100644 index 000000000..48fc85355 --- /dev/null +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift @@ -0,0 +1,189 @@ +// +// ExperimentCohortsManagerTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import BrowserServicesKit + +final class ExperimentCohortsManagerTests: XCTestCase { + + let cohort1 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort1", "weight": 1])! + let cohort2 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort2", "weight": 0])! + let cohort3 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort3", "weight": 2])! + let cohort4 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort4", "weight": 0])! + + var mockStore: MockExperimentDataStore! + var experimentCohortsManager: ExperimentCohortsManager! + + let subfeatureName1 = "TestSubfeature1" + var experimentData1: ExperimentData! + + let subfeatureName2 = "TestSubfeature2" + var experimentData2: ExperimentData! + + let subfeatureName3 = "TestSubfeature3" + var experimentData3: ExperimentData! + + let subfeatureName4 = "TestSubfeature4" + var experimentData4: ExperimentData! + + let encoder: JSONEncoder = { + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .secondsSince1970 + return encoder + }() + + override func setUp() { + super.setUp() + mockStore = MockExperimentDataStore() + experimentCohortsManager = ExperimentCohortsManager( + store: mockStore + ) + + let expectedDate1 = Date() + experimentData1 = ExperimentData(parentID: "TestParent", cohortID: cohort1.name, enrollmentDate: expectedDate1) + + let expectedDate2 = Date().addingTimeInterval(60) + experimentData2 = ExperimentData(parentID: "TestParent", cohortID: cohort2.name, enrollmentDate: expectedDate2) + + let expectedDate3 = Date() + experimentData3 = ExperimentData(parentID: "TestParent", cohortID: cohort3.name, enrollmentDate: expectedDate3) + + let expectedDate4 = Date().addingTimeInterval(60) + experimentData4 = ExperimentData(parentID: "TestParent", cohortID: cohort4.name, enrollmentDate: expectedDate4) + } + + override func tearDown() { + mockStore = nil + experimentCohortsManager = nil + experimentData1 = nil + experimentData2 = nil + super.tearDown() + } + + func testExperimentReturnAssignedExperiments() { + // GIVEN + mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] + + // WHEN + let experiments = experimentCohortsManager.experiments + + // THEN + XCTAssertEqual(experiments?.count, 2) + XCTAssertEqual(experiments?[subfeatureName1], experimentData1) + XCTAssertEqual(experiments?[subfeatureName2], experimentData2) + XCTAssertNil(experiments?[subfeatureName3]) + } + + func testCohortReturnsCohortIDIfExistsForMultipleSubfeatures() { + // GIVEN + mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] + + // WHEN + let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort1, cohort2]), allowCohortReassignment: false) + let result2 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData2.parentID, subfeatureID: subfeatureName2, cohorts: [cohort2, cohort3]), allowCohortReassignment: false) + + // THEN + XCTAssertEqual(result1, experimentData1.cohortID) + XCTAssertEqual(result2, experimentData2.cohortID) + } + + func testCohortAssignIfEnabledWhenNoCohortExists() { + // GIVEN + mockStore.experiments = [:] + let cohorts = [cohort1, cohort2] + let experiment = ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: cohorts) + + // WHEN + let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true) + + // THEN + XCTAssertNotNil(result) + XCTAssertEqual(result, experimentData1.cohortID) + } + + func testCohortDoesNotAssignIfAssignIfEnabledIsFalse() { + // GIVEN + mockStore.experiments = [:] + let cohorts = [cohort1, cohort2] + let experiment = ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: cohorts) + + // WHEN + let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: false) + + // THEN + XCTAssertNil(result) + } + + func testCohortDoesNotAssignIfAssignIfEnabledIsTrueButNoCohortsAvailable() { + // GIVEN + mockStore.experiments = [:] + let experiment = ExperimentSubfeature(parentID: "TestParent", subfeatureID: "NonExistentSubfeature", cohorts: []) + + // WHEN + let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true) + + // THEN + XCTAssertNil(result) + } + + func testCohortReassignsCohortIfAssignedCohortDoesNotExistAndAssignIfEnabledIsTrue() { + // GIVEN + mockStore.experiments = [subfeatureName1: experimentData1] + + // WHEN + let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort2, cohort3]), allowCohortReassignment: true) + + // THEN + XCTAssertEqual(result1, experimentData3.cohortID) + } + + func testCohortDoesNotReassignsCohortIfAssignedCohortDoesNotExistAndAssignIfEnabledIsTrue() { + // GIVEN + mockStore.experiments = [subfeatureName1: experimentData1] + + // WHEN + let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort2, cohort3]), allowCohortReassignment: false) + + // THEN + XCTAssertNil(result1) + } + + func testCohortAssignsBasedOnWeight() { + // GIVEN + let experiment = ExperimentSubfeature(parentID: experimentData3.parentID, subfeatureID: subfeatureName3, cohorts: [cohort3, cohort4]) + + let randomizer: (Range) -> Double = { range in + return 1.5 + } + + experimentCohortsManager = ExperimentCohortsManager( + store: mockStore, + randomizer: randomizer + ) + + // WHEN + let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true) + + // THEN + XCTAssertEqual(result, experimentData3.cohortID) + } +} + +class MockExperimentDataStore: ExperimentsDataStoring { + var experiments: Experiments? +} diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift similarity index 80% rename from Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift rename to Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift index 0466155b0..77da51663 100644 --- a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift @@ -47,8 +47,8 @@ final class ExperimentsDataStoreTests: XCTestCase { func testExperimentsGetReturnsDecodedExperiments() { // GIVEN - let experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: Date()) - let experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: Date()) + let experimentData1 = ExperimentData(parentID: "parent", cohortID: "TestCohort1", enrollmentDate: Date()) + let experimentData2 = ExperimentData(parentID: "parent", cohortID: "TestCohort2", enrollmentDate: Date()) let experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] let encoder = JSONEncoder() @@ -62,17 +62,17 @@ final class ExperimentsDataStoreTests: XCTestCase { // THEN let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(result?[subfeatureName1]?.enrollmentDate ?? Date())) let timeDifference2 = abs(experimentData2.enrollmentDate.timeIntervalSince(result?[subfeatureName2]?.enrollmentDate ?? Date())) - XCTAssertEqual(result?[subfeatureName1]?.cohort, experimentData1.cohort) + XCTAssertEqual(result?[subfeatureName1]?.cohortID, experimentData1.cohortID) XCTAssertLessThanOrEqual(timeDifference1, 1.0) - XCTAssertEqual(result?[subfeatureName2]?.cohort, experimentData2.cohort) + XCTAssertEqual(result?[subfeatureName2]?.cohortID, experimentData2.cohortID) XCTAssertLessThanOrEqual(timeDifference2, 1.0) } func testExperimentsSetEncodesAndStoresData() throws { // GIVEN - let experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: Date()) - let experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: Date()) + let experimentData1 = ExperimentData(parentID: "parent", cohortID: "TestCohort1", enrollmentDate: Date()) + let experimentData2 = ExperimentData(parentID: "parent2", cohortID: "TestCohort2", enrollmentDate: Date()) let experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] // WHEN @@ -85,9 +85,11 @@ final class ExperimentsDataStoreTests: XCTestCase { let decodedExperiments = try? decoder.decode(Experiments.self, from: storedData) let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(decodedExperiments?[subfeatureName1]?.enrollmentDate ?? Date())) let timeDifference2 = abs(experimentData2.enrollmentDate.timeIntervalSince(decodedExperiments?[subfeatureName2]?.enrollmentDate ?? Date())) - XCTAssertEqual(decodedExperiments?[subfeatureName1]?.cohort, experimentData1.cohort) + XCTAssertEqual(decodedExperiments?[subfeatureName1]?.cohortID, experimentData1.cohortID) + XCTAssertEqual(decodedExperiments?[subfeatureName1]?.parentID, experimentData1.parentID) XCTAssertLessThanOrEqual(timeDifference1, 1.0) - XCTAssertEqual(decodedExperiments?[subfeatureName2]?.cohort, experimentData2.cohort) + XCTAssertEqual(decodedExperiments?[subfeatureName2]?.cohortID, experimentData2.cohortID) + XCTAssertEqual(decodedExperiments?[subfeatureName2]?.parentID, experimentData2.parentID) XCTAssertLessThanOrEqual(timeDifference2, 1.0) } } diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift index 5e2407ca1..5223ff059 100644 --- a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift @@ -46,7 +46,7 @@ final class FeatureFlagLocalOverridesTests: XCTestCase { let internalUserDecider = DefaultInternalUserDecider(store: internalUserDeciderStore) let privacyConfig = MockPrivacyConfiguration() let privacyConfigManager = MockPrivacyConfigurationManager(privacyConfig: privacyConfig, internalUserDecider: internalUserDecider) - featureFlagger = DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: privacyConfigManager) + featureFlagger = DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: privacyConfigManager, experimentManager: nil) keyValueStore = MockKeyValueStore() actionHandler = CapturingFeatureFlagLocalOverridesHandler() diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift new file mode 100644 index 000000000..b6931ef22 --- /dev/null +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift @@ -0,0 +1,1228 @@ +// +// FeatureFlaggerExperimentsTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +@testable import BrowserServicesKit + +struct CredentialsSavingFlag: FeatureFlagExperimentDescribing { + + typealias CohortType = Cohort + + var rawValue = "credentialSaving" + + var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.credentialsSaving)) + + enum Cohort: String, FlagCohort { + case control + case blue + case red + } +} + +struct InlineIconCredentialsFlag: FeatureFlagExperimentDescribing { + + typealias CohortType = Cohort + + var rawValue = "inlineIconCredentials" + + var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.inlineIconCredentials)) + + enum Cohort: String, FlagCohort { + case control + case blue + case green + } +} + +struct AccessCredentialManagementFlag: FeatureFlagExperimentDescribing { + + typealias CohortType = Cohort + + var rawValue = "accessCredentialManagement" + + var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.accessCredentialManagement)) + + enum Cohort: String, FlagCohort { + case control + case blue + case green + } +} + +final class FeatureFlaggerExperimentsTests: XCTestCase { + + var featureJson: Data = "{}".data(using: .utf8)! + var mockEmbeddedData: MockEmbeddedDataProvider! + var mockStore: MockExperimentDataStore! + var experimentManager: ExperimentCohortsManager! + var manager: PrivacyConfigurationManager! + var locale: Locale! + var featureFlagger: FeatureFlagger! + + let subfeatureName = "credentialsSaving" + + override func setUp() { + locale = Locale(identifier: "fr_US") + mockEmbeddedData = MockEmbeddedDataProvider(data: featureJson, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + mockStore = MockExperimentDataStore() + experimentManager = ExperimentCohortsManager(store: mockStore) + manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + featureFlagger = DefaultFeatureFlagger(internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), privacyConfigManager: manager, experimentManager: experimentManager) + } + + override func tearDown() { + featureJson = "".data(using: .utf8)! + mockEmbeddedData = nil + mockStore = nil + experimentManager = nil + manager = nil + } + + func testCohortOnlyAssignedWhenCallingStateForSubfeature() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + let config = manager.privacyConfig + + // we haven't called getCohortIfEnabled yet, so cohorts should not be yet assigned + XCTAssertNil(mockStore.experiments) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + // we call isSubfeatureEnabled() hould not be assigned either + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled) + XCTAssertNil(mockStore.experiments) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + // we call getCohortIfEnabled(cohort), then we should assign cohort + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + } + + func testRemoveAllCohortsRemotelyRemovesAssignedCohort() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + var config = manager.privacyConfig + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // remove blue cohort + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + config = manager.privacyConfig + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // remove all remaining cohorts + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2 + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + config = manager.privacyConfig + + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertTrue(mockStore.experiments?.isEmpty ?? false) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + } + + func testRemoveAssignedCohortsRemotelyRemovesAssignedCohortAndTriesToReassign() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "2", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // remove blue cohort + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "red", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "2", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.red.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "red") + } + + func testDisablingFeatureDisablesCohort() { + // Initially subfeature for both cohorts is disabled + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertNil(mockStore.experiments) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + // When features with cohort the cohort with weight 1 is enabled + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // If the subfeature is then disabled isSubfeatureEnabled should return false + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "disabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // If the subfeature is parent feature disabled isSubfeatureEnabled should return false + featureJson = + """ + { + "features": { + "autofill": { + "state": "disabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + } + + func testCohortsAndTargetsInteraction() { + func featureJson(country: String, language: String) -> Data { + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeLanguage": "\(language)", + "localeCountry": "\(country)" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + } + manager.reload(etag: "", data: featureJson(country: "FR", language: "fr")) + + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertNil(mockStore.experiments) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + manager.reload(etag: "", data: featureJson(country: "US", language: "en")) + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertNil(mockStore.experiments) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + manager.reload(etag: "", data: featureJson(country: "US", language: "fr")) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // once cohort is assigned, changing targets shall not affect feature state + manager.reload(etag: "", data: featureJson(country: "IT", language: "it")) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + let featureJson2 = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "FR" + } + ], + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson2) + + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertTrue(mockStore.experiments?.isEmpty ?? false) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + + // re-populate experiment to re-assign new cohort, should not be assigned as it has wrong targets + manager.reload(etag: "", data: featureJson(country: "IT", language: "it")) + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())) + XCTAssertTrue(mockStore.experiments?.isEmpty ?? false) + XCTAssertNil(experimentManager.cohort(for: subfeatureName)) + } + + func testChangeRemoteCohortsAfterAssignmentShouldNoop() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // changing targets should not change cohort assignment + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "IT" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // changing cohort weight should not change current assignment + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 0 + }, + { + "name": "blue", + "weight": 1 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // adding cohorts should not change current assignment + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 1 + }, + { + "name": "red", + "weight": 1 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + } + + func testEnrollmentDate() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "", data: featureJson) + let config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertTrue(mockStore.experiments?.isEmpty ?? true) + XCTAssertNil(experimentManager.cohort(for: subfeatureName), "control") + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + let currentTime = Date().timeIntervalSince1970 + let enrollmentTime = mockStore.experiments?[subfeatureName]?.enrollmentDate.timeIntervalSince1970 + + XCTAssertNotNil(enrollmentTime) + if let enrollmentTime = enrollmentTime { + let tolerance: TimeInterval = 60 // 1 minute in seconds + XCTAssertEqual(currentTime, enrollmentTime, accuracy: tolerance) + } + } + + func testRollbackCohortExperiments() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "rollout": { + "steps": [ + { + "percent": 100 + } + ] + }, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + var config = manager.privacyConfig + clearRolloutData(feature: "autofill", subFeature: "credentialsSaving") + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "rollout": { + "steps": [ + { + "percent": 0 + } + ] + }, + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + } + + func testCohortEnabledAndStopEnrollmentAndRhenRollBack() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + var config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // Stop enrollment, should keep assigned cohorts + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 0 + }, + { + "name": "blue", + "weight": 1 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control") + + // remove control, should re-allocate to blue + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "blue", + "weight": 1 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.blue.rawValue) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "blue") + } + + private func clearRolloutData(feature: String, subFeature: String) { + UserDefaults().set(nil, forKey: "config.\(feature).\(subFeature).enabled") + UserDefaults().set(nil, forKey: "config.\(feature).\(subFeature).lastRolloutCount") + } + + func testAllActiveExperimentsEmptyIfNoAssignedExperiment() { + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + manager.reload(etag: "foo", data: featureJson) + + let activeExperiments = featureFlagger.getAllActiveExperiments() + XCTAssertTrue(activeExperiments.isEmpty) + XCTAssertNil(mockStore.experiments) + } + + func testAllActiveExperimentsReturnsOnlyActiveExperiments() { + var featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "blue", + "weight": 0 + } + ] + }, + "inlineIconCredentials": { + "state": "enabled", + "minSupportedVersion": 1, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 0 + }, + { + "name": "green", + "weight": 1 + } + ] + }, + "accessCredentialManagement": { + "state": "enabled", + "minSupportedVersion": 3, + "targets": [ + { + "localeCountry": "CA" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "green", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + + manager.reload(etag: "foo", data: featureJson) + + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.Cohort.control.rawValue) + XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: InlineIconCredentialsFlag())?.rawValue, InlineIconCredentialsFlag.Cohort.green.rawValue) + XCTAssertNil(featureFlagger.getCohortIfEnabled(for: AccessCredentialManagementFlag())) + + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + XCTAssertEqual(experimentManager.cohort(for: AutofillSubfeature.credentialsSaving.rawValue), "control") + XCTAssertEqual(experimentManager.cohort(for: AutofillSubfeature.inlineIconCredentials.rawValue), "green") + XCTAssertNil(experimentManager.cohort(for: AutofillSubfeature.accessCredentialManagement.rawValue)) + XCTAssertFalse(mockStore.experiments?.isEmpty ?? true) + + var activeExperiments = featureFlagger.getAllActiveExperiments() + XCTAssertEqual(activeExperiments.count, 2) + XCTAssertEqual(activeExperiments[AutofillSubfeature.credentialsSaving.rawValue]?.cohortID, "control") + XCTAssertEqual(activeExperiments[AutofillSubfeature.inlineIconCredentials.rawValue]?.cohortID, "green") + XCTAssertNil(activeExperiments[AutofillSubfeature.accessCredentialManagement.rawValue]) + + // When an assigned cohort is removed it's not part of active experiments + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "blue", + "weight": 1 + } + ] + }, + "inlineIconCredentials": { + "state": "enabled", + "minSupportedVersion": 1, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 0 + }, + { + "name": "green", + "weight": 1 + } + ] + }, + "accessCredentialManagement": { + "state": "enabled", + "minSupportedVersion": 3, + "targets": [ + { + "localeCountry": "CA" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "green", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + + manager.reload(etag: "foo", data: featureJson) + + activeExperiments = featureFlagger.getAllActiveExperiments() + XCTAssertEqual(activeExperiments.count, 1) + XCTAssertNil(activeExperiments[AutofillSubfeature.credentialsSaving.rawValue]) + XCTAssertEqual(activeExperiments[AutofillSubfeature.inlineIconCredentials.rawValue]?.cohortID, "green") + XCTAssertNil(activeExperiments[AutofillSubfeature.accessCredentialManagement.rawValue]) + + // When feature disabled an assigned cohort it's not part of active experiments + featureJson = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "minSupportedVersion": 2, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "blue", + "weight": 1 + } + ] + }, + "inlineIconCredentials": { + "state": "disabled", + "minSupportedVersion": 1, + "targets": [ + { + "localeCountry": "US" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 0 + }, + { + "name": "green", + "weight": 1 + } + ] + }, + "accessCredentialManagement": { + "state": "enabled", + "minSupportedVersion": 3, + "targets": [ + { + "localeCountry": "CA" + } + ], + "cohorts": [ + { + "name": "control", + "weight": 1 + }, + { + "name": "green", + "weight": 0 + } + ] + } + } + } + } + } + """.data(using: .utf8)! + + manager.reload(etag: "foo", data: featureJson) + + activeExperiments = featureFlagger.getAllActiveExperiments() + XCTAssertTrue(activeExperiments.isEmpty) + } + +} diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift index 8c1a43ee7..8d8ec9929 100644 --- a/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift +++ b/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift @@ -35,11 +35,11 @@ enum TestFeatureFlag: String, FeatureFlagDescribing { var source: FeatureFlagSource { switch self { case .nonOverridableFlag: - return .internalOnly + return .internalOnly() case .overridableFlagDisabledByDefault: return .disabled case .overridableFlagEnabledByDefault: - return .internalOnly + return .internalOnly() } } } diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift index 252a7b866..9388b14e9 100644 --- a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift +++ b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift @@ -907,6 +907,186 @@ class AppPrivacyConfigurationTests: XCTestCase { XCTAssertEqual(configAfterUpdate.stateFor(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:)), .disabled(.disabledInConfig)) } + let exampleSubfeatureEnabledWithTarget = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "targets": [ + { + "localeCountry": "US", + "localeLanguage": "fr" + } + ] + } + } + } + } + } + """.data(using: .utf8)! + + func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenTargetMatches_SubfeatureShouldBeEnabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "fr_US") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled) + } + + func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenRegionDoesNotMatches_SubfeatureShouldBeDisabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "fr_FR") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.targetDoesNotMatch)) + } + + func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenLanguageDoesNotMatches_SubfeatureShouldBeDisabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "it_US") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.targetDoesNotMatch)) + } + + let exampleSubfeatureDisabledWithTarget = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "disabled", + "targets": [ + { + "localeCountry": "US", + "localeLanguage": "fr" + } + ] + } + } + } + } + } + """.data(using: .utf8)! + + func testWhenCheckingSubfeatureStateWithSubfeatureDisabledWhenTargetMatches_SubfeatureShouldBeDisabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureDisabledWithTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "fr_US") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving)) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.disabledInConfig)) + } + + let exampleSubfeatureEnabledWithRolloutAndTarget = + """ + { + "features": { + "autofill": { + "state": "enabled", + "exceptions": [], + "features": { + "credentialsSaving": { + "state": "enabled", + "targets": [ + { + "localeCountry": "US", + "localeLanguage": "fr" + } + ], + "rollout": { + "steps": [{ + "percent": 5.0 + }] + } + } + } + } + } + } + """.data(using: .utf8)! + + func testWhenCheckingSubfeatureStateWithSubfeatureEnabledAndTargetMatchesWhenNotInRollout_SubfeatureShouldBeDisabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithRolloutAndTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "fr_US") + mockRandomValue = 7.0 + clearRolloutData(feature: "autofill", subFeature: "credentialsSaving") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:))) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.stillInRollout)) + } + + func testWhenCheckingSubfeatureStateWithSubfeatureEnabledAndTargetMatchesWhenInRollout_SubfeatureShouldBeEnabled() { + let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithRolloutAndTarget, etag: "test") + let mockInternalUserStore = MockInternalUserStoring() + let locale = Locale(identifier: "fr_US") + mockRandomValue = 2.0 + clearRolloutData(feature: "autofill", subFeature: "credentialsSaving") + + let manager = PrivacyConfigurationManager(fetchedETag: nil, + fetchedData: nil, + embeddedDataProvider: mockEmbeddedData, + localProtection: MockDomainsProtectionStore(), + internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), + locale: locale) + let config = manager.privacyConfig + + XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:))) + XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled) + } + let exampleEnabledSubfeatureWithRollout = """ { diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift deleted file mode 100644 index 518249560..000000000 --- a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift +++ /dev/null @@ -1,266 +0,0 @@ -// -// ExperimentCohortsManagerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import XCTest -@testable import BrowserServicesKit - -final class ExperimentCohortsManagerTests: XCTestCase { - - var mockStore: MockExperimentDataStore! - var experimentCohortsManager: ExperimentCohortsManager! - - let subfeatureName1 = "TestSubfeature1" - var experimentData1: ExperimentData! - - let subfeatureName2 = "TestSubfeature2" - var experimentData2: ExperimentData! - - let encoder: JSONEncoder = { - let encoder = JSONEncoder() - encoder.dateEncodingStrategy = .secondsSince1970 - return encoder - }() - - override func setUp() { - super.setUp() - mockStore = MockExperimentDataStore() - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 50.0 } - ) - - let expectedDate1 = Date() - experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: expectedDate1) - - let expectedDate2 = Date().addingTimeInterval(60) - experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: expectedDate2) - } - - override func tearDown() { - mockStore = nil - experimentCohortsManager = nil - experimentData1 = nil - experimentData2 = nil - super.tearDown() - } - - func testCohortReturnsCohortIDIfExistsForMultipleSubfeatures() { - // GIVEN - mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] - - // WHEN - let result1 = experimentCohortsManager.cohort(for: subfeatureName1) - let result2 = experimentCohortsManager.cohort(for: subfeatureName2) - - // THEN - XCTAssertEqual(result1, experimentData1.cohort) - XCTAssertEqual(result2, experimentData2.cohort) - } - - func testEnrollmentDateReturnsCorrectDateIfExists() { - // GIVEN - mockStore.experiments = [subfeatureName1: experimentData1] - - // WHEN - let result1 = experimentCohortsManager.enrollmentDate(for: subfeatureName1) - let result2 = experimentCohortsManager.enrollmentDate(for: subfeatureName2) - - // THEN - let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(result1 ?? Date())) - - XCTAssertLessThanOrEqual(timeDifference1, 1.0, "Expected enrollment date for subfeatureName1 to match at the second level") - XCTAssertNil(result2) - } - - func testCohortReturnsNilIfCohortDoesNotExist() { - // GIVEN - let subfeatureName = "TestSubfeature" - - // WHEN - let result = experimentCohortsManager.cohort(for: subfeatureName) - - // THEN - XCTAssertNil(result) - } - - func testEnrollmentDateReturnsNilIfDateDoesNotExist() { - // GIVEN - let subfeatureName = "TestSubfeature" - - // WHEN - let result = experimentCohortsManager.enrollmentDate(for: subfeatureName) - - // THEN - XCTAssertNil(result) - } - - func testRemoveCohortSuccessfullyRemovesData() throws { - // GIVEN - mockStore.experiments = [subfeatureName1: experimentData1] - - // WHEN - experimentCohortsManager.removeCohort(from: subfeatureName1) - - // THEN - let experiments = try XCTUnwrap(mockStore.experiments) - XCTAssertTrue(experiments.isEmpty) - } - - func testRemoveCohortDoesNothingIfSubfeatureDoesNotExist() { - // GIVEN - let expectedExperiments: Experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2] - mockStore.experiments = expectedExperiments - - // WHEN - experimentCohortsManager.removeCohort(from: "someOtherSubfeature") - - // THEN - XCTAssertEqual( mockStore.experiments, expectedExperiments) - } - - func testAssignCohortReturnsNilIfNoCohorts() { - // GIVEN - let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: []) - - // WHEN - let result = experimentCohortsManager.assignCohort(to: subfeature) - - // THEN - XCTAssertNil(result) - } - - func testAssignCohortReturnsNilIfAllWeightsAreZero() { - // GIVEN - let jsonCohort1: [String: Any] = ["name": "TestCohort", "weight": 0] - let jsonCohort2: [String: Any] = ["name": "TestCohort", "weight": 0] - let cohorts = [ - PrivacyConfigurationData.Cohort(json: jsonCohort1)!, - PrivacyConfigurationData.Cohort(json: jsonCohort2)! - ] - let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts) - - // WHEN - let result = experimentCohortsManager.assignCohort(to: subfeature) - - // THEN - XCTAssertNil(result) - } - - func testAssignCohortSelectsCorrectCohortBasedOnWeight() { - // Cohort1 has weight 1, Cohort2 has weight 3 - // Total weight is 1 + 3 = 4 - let jsonCohort1: [String: Any] = ["name": "Cohort1", "weight": 1] - let jsonCohort2: [String: Any] = ["name": "Cohort2", "weight": 3] - let cohorts = [ - PrivacyConfigurationData.Cohort(json: jsonCohort1)!, - PrivacyConfigurationData.Cohort(json: jsonCohort2)! - ] - let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts) - let expectedTotalWeight = 4.0 - - // Use a custom randomizer to verify the range - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { range in - // Assert that the range lower bound is 0 - XCTAssertEqual(range.lowerBound, 0.0) - // Assert that the range upper bound is the total weight - XCTAssertEqual(range.upperBound, expectedTotalWeight) - return 0.0 - } - ) - - // Test case where random value is at the very start of Cohort1's range (0) - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 0.0 } - ) - let resultStartOfCohort1 = experimentCohortsManager.assignCohort(to: subfeature) - XCTAssertEqual(resultStartOfCohort1, "Cohort1") - - // Test case where random value is at the end of Cohort1's range (0.9) - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 0.9 } - ) - let resultEndOfCohort1 = experimentCohortsManager.assignCohort(to: subfeature) - XCTAssertEqual(resultEndOfCohort1, "Cohort1") - - // Test case where random value is at the start of Cohort2's range (1.00 to 4) - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 1.00 } - ) - let resultStartOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature) - XCTAssertEqual(resultStartOfCohort2, "Cohort2") - - // Test case where random value falls exactly within Cohort2's range (2.5) - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 2.5 } - ) - let resultMiddleOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature) - XCTAssertEqual(resultMiddleOfCohort2, "Cohort2") - - // Test case where random value is at the end of Cohort2's range (4) - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { _ in 3.9 } - ) - let resultEndOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature) - XCTAssertEqual(resultEndOfCohort2, "Cohort2") - } - - func testAssignCohortWithSingleCohortAlwaysSelectsThatCohort() throws { - // GIVEN - let jsonCohort1: [String: Any] = ["name": "Cohort1", "weight": 1] - let cohorts = [ - PrivacyConfigurationData.Cohort(json: jsonCohort1)! - ] - let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts) - let expectedTotalWeight = 1.0 - - // Use a custom randomizer to verify the range - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { range in - // Assert that the range lower bound is 0 - XCTAssertEqual(range.lowerBound, 0.0) - // Assert that the range upper bound is the total weight - XCTAssertEqual(range.upperBound, expectedTotalWeight) - return 0.0 - } - ) - - // WHEN - experimentCohortsManager = ExperimentCohortsManager( - store: mockStore, - randomizer: { range in Double.random(in: range)} - ) - let result = experimentCohortsManager.assignCohort(to: subfeature) - - // THEN - XCTAssertEqual(result, "Cohort1") - XCTAssertEqual(cohorts[0].name, mockStore.experiments?[subfeature.subfeatureID]?.cohort) - } - -} - -class MockExperimentDataStore: ExperimentsDataStoring { - var experiments: Experiments? -} diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift index fac814b1c..08ecbec85 100644 --- a/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift +++ b/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift @@ -68,6 +68,9 @@ class PrivacyConfigurationDataTests: XCTestCase { XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?.count, 3) XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?[0].name, "myExperimentControl") XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?[0].weight, 1) + XCTAssertEqual(subfeatures["enabledSubfeature"]?.targets?[0].localeCountry, "US") + XCTAssertEqual(subfeatures["enabledSubfeature"]?.targets?[0].localeLanguage, "fr") + XCTAssertEqual(subfeatures["enabledSubfeature"]?.settings, "{\"foo\":\"foo\\/value\",\"bar\":\"bar\\/value\"}") XCTAssertEqual(subfeatures["internalSubfeature"]?.state, "internal") } else { XCTFail("Could not parse subfeatures") diff --git a/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json b/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json index 3fd5be5a5..735c3914f 100644 --- a/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json +++ b/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json @@ -171,6 +171,16 @@ }, "enabledSubfeature": { "state": "enabled", + "targets": [ + { + "localeCountry": "US", + "localeLanguage": "fr" + } + ], + "settings": { + "foo": "foo/value", + "bar": "bar/value" + }, "description": "A description of the sub-feature", "cohorts": [ { diff --git a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests index 6133e7d9d..a603ff9af 160000 --- a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests +++ b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests @@ -1 +1 @@ -Subproject commit 6133e7d9d9cd5f1b925cab1971b4d785dc639df7 +Subproject commit a603ff9af22ca3ff7ce2e7ffbfe18c447d9f23e8 diff --git a/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift b/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift index a85dadc1e..388ad69b9 100644 --- a/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift +++ b/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift @@ -221,17 +221,29 @@ class MockPrivacyConfiguration: PrivacyConfiguration { var isSubfeatureEnabledCheck: ((any PrivacySubfeature) -> Bool)? - func isSubfeatureEnabled(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> Bool { + func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool { isSubfeatureEnabledCheck?(subfeature) ?? false } - func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { if isSubfeatureEnabledCheck?(subfeature) == true { return .enabled } return .disabled(.disabledInConfig) } + func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { + return .enabled + } + + func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + + func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + var identifier: String = "abcd" var version: String? = "123456789" var userUnprotectedDomains: [String] = [] diff --git a/Tests/CommonTests/Extensions/StringExtensionTests.swift b/Tests/CommonTests/Extensions/StringExtensionTests.swift index bcf415895..65abb79c8 100644 --- a/Tests/CommonTests/Extensions/StringExtensionTests.swift +++ b/Tests/CommonTests/Extensions/StringExtensionTests.swift @@ -16,8 +16,10 @@ // limitations under the License. // +import CryptoKit import Foundation import XCTest + @testable import Common final class StringExtensionTests: XCTestCase { @@ -370,4 +372,13 @@ final class StringExtensionTests: XCTestCase { } } + func testSha256() { + let string = "Hello, World! This is a test string." + let hash = string.sha256 + let expected = "3c2b805ab0038afb0629e1d598ae73e0caabb69de03e96762977d34e8ba428bf" + let expectedSHA256 = SHA256.hash(data: Data(string.utf8)).map { String(format: "%02hhx", $0) }.joined() + XCTAssertEqual(hash, expected) + XCTAssertEqual(hash, expectedSHA256) + } + } diff --git a/Tests/DDGSyncTests/Mocks/Mocks.swift b/Tests/DDGSyncTests/Mocks/Mocks.swift index 013123c6d..4745f155c 100644 --- a/Tests/DDGSyncTests/Mocks/Mocks.swift +++ b/Tests/DDGSyncTests/Mocks/Mocks.swift @@ -169,18 +169,26 @@ class MockPrivacyConfiguration: PrivacyConfiguration { return .enabled } - func isSubfeatureEnabled( - _ subfeature: any PrivacySubfeature, - versionProvider: AppVersionProvider, - randomizer: (Range) -> Double - ) -> Bool { + func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool { true } - func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState { + func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { return .enabled } + func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState { + return .enabled + } + + func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + + func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? { + return nil + } + var identifier: String = "abcd" var version: String? = "123456789" var userUnprotectedDomains: [String] = [] diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift new file mode 100644 index 000000000..f6b0de23a --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift @@ -0,0 +1,103 @@ +// +// MaliciousSiteDetectorTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +import Foundation +import XCTest + +@testable import MaliciousSiteProtection + +class MaliciousSiteDetectorTests: XCTestCase { + + private var mockAPIClient: MockMaliciousSiteProtectionAPIClient! + private var mockDataManager: MockMaliciousSiteProtectionDataManager! + private var mockEventMapping: MockEventMapping! + private var detector: MaliciousSiteDetector! + + override func setUp() async throws { + mockAPIClient = MockMaliciousSiteProtectionAPIClient() + mockDataManager = MockMaliciousSiteProtectionDataManager() + mockEventMapping = MockEventMapping() + detector = MaliciousSiteDetector(apiClient: mockAPIClient, dataManager: mockDataManager, eventMapping: mockEventMapping) + } + + override func tearDown() async throws { + mockAPIClient = nil + mockDataManager = nil + mockEventMapping = nil + detector = nil + } + + func testIsMaliciousWithLocalFilterHit() async { + let filter = Filter(hash: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*") + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["255a8a79"]), for: .hashPrefixes(threatKind: .phishing)) + + let url = URL(string: "https://malicious.com/")! + + let result = await detector.evaluate(url) + + XCTAssertEqual(result, .phishing) + } + + func testIsMaliciousWithApiMatch() async { + await mockDataManager.store(FilterDictionary(revision: 0, items: []), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["a379a6f6"]), for: .hashPrefixes(threatKind: .phishing)) + + let url = URL(string: "https://example.com/mal")! + + let result = await detector.evaluate(url) + + XCTAssertEqual(result, .phishing) + } + + func testIsMaliciousWithHashPrefixMatch() async { + let filter = Filter(hash: "notamatch", regex: ".*malicious.*") + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24" /* matches safe.com */]), for: .hashPrefixes(threatKind: .phishing)) + + let url = URL(string: "https://safe.com")! + + let result = await detector.evaluate(url) + + XCTAssertNil(result) + } + + func testIsMaliciousWithFullHashMatch() async { + // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b + let filter = Filter(hash: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI") + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24"]), for: .hashPrefixes(threatKind: .phishing)) + + let url = URL(string: "https://safe.com")! + + let result = await detector.evaluate(url) + + XCTAssertNil(result) + } + + func testIsMaliciousWithNoHashPrefixMatch() async { + let filter = Filter(hash: "testHash", regex: ".*malicious.*") + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["testPrefix"]), for: .hashPrefixes(threatKind: .phishing)) + + let url = URL(string: "https://safe.com")! + + let result = await detector.evaluate(url) + + XCTAssertNil(result) + } +} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift new file mode 100644 index 000000000..fcea80939 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -0,0 +1,143 @@ +// +// MaliciousSiteProtectionAPIClientTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +import Foundation +import Networking +import TestUtils +import XCTest + +@testable import MaliciousSiteProtection + +final class MaliciousSiteProtectionAPIClientTests: XCTestCase { + + var mockService: MockAPIService! + var client: MaliciousSiteProtection.APIClient! + + override func setUp() { + super.setUp() + mockService = MockAPIService() + client = .init(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService) + } + + override func tearDown() { + mockService = nil + client = nil + super.tearDown() + } + + func testWhenPhishingFilterSetRequestedAndSucceeds_ChangeSetIsReturned() async throws { + // Given + let insertFilter = Filter(hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".") + let deleteFilter = Filter(hash: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?") + let expectedResponse = APIClient.Response.FiltersChangeSet(insert: [insertFilter], delete: [deleteFilter], revision: 666, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .filterSet(.init(threatKind: .phishing, revision: 666)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } + + // When + let response = try await client.filtersChangeSet(for: .phishing, revision: 666) + + // Then + XCTAssertEqual(response, expectedResponse) + } + + func testWhenHashPrefixesRequestedAndSucceeds_ChangeSetIsReturned() async throws { + // Given + let expectedResponse = APIClient.Response.HashPrefixesChangeSet(insert: ["abc"], delete: ["def"], revision: 1, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } + + // When + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: 1) + + // Then + XCTAssertEqual(response, expectedResponse) + } + + func testWhenMatchesRequestedAndSucceeds_MatchesAreReturned() async throws { + // Given + let expectedResponse = APIClient.Response.Matches(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .matches(.init(hashPrefix: "abc")))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } + + // When + let response = try await client.matches(forHashPrefix: "abc") + + // Then + XCTAssertEqual(response.matches, expectedResponse.matches) + } + + func testWhenHashPrefixesRequestFails_ErrorThrown() async throws { + // Given + let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } + + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } + } + + func testWhenFilterSetRequestFails_ErrorThrown() async throws { + // Given + let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } + + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } + } + + func testWhenMatchesRequestFails_ErrorThrown() async throws { + // Given + let invalidHashPrefix = "" + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } + + do { + let response = try await client.matches(forHashPrefix: invalidHashPrefix) + XCTFail("Unexpected \(response) expected throw") + } catch { + } + } + +} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift new file mode 100644 index 000000000..5164f78d3 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift @@ -0,0 +1,250 @@ +// +// MaliciousSiteProtectionDataManagerTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +import XCTest + +@testable import MaliciousSiteProtection + +class MaliciousSiteProtectionDataManagerTests: XCTestCase { + var embeddedDataProvider: MockMaliciousSiteProtectionEmbeddedDataProvider! + enum Constants { + static let hashPrefixesFileName = "phishingHashPrefixes.json" + static let filterSetFileName = "phishingFilterSet.json" + } + let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName] + var dataManager: MaliciousSiteProtection.DataManager! + var fileStore: MaliciousSiteProtection.FileStoring! + + override func setUp() async throws { + embeddedDataProvider = MockMaliciousSiteProtectionEmbeddedDataProvider() + fileStore = MockMaliciousSiteProtectionFileStore() + setUpDataManager() + } + + func setUpDataManager() { + dataManager = MaliciousSiteProtection.DataManager(fileStore: fileStore, embeddedDataProvider: embeddedDataProvider, fileNameProvider: { dataType in + switch dataType { + case .filterSet: Constants.filterSetFileName + case .hashPrefixSet: Constants.hashPrefixesFileName + } + }) + } + + override func tearDown() async throws { + embeddedDataProvider = nil + dataManager = nil + } + + func clearDatasets() { + for fileName in datasetFiles { + let emptyData = Data() + fileStore.write(data: emptyData, to: fileName) + } + } + + func testWhenNoDataSavedThenProviderDataReturned() async { + clearDatasets() + let expectedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let expectedFilterDict = FilterDictionary(revision: 65, items: expectedFilterSet) + let expectedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = expectedFilterSet + embeddedDataProvider.hashPrefixes = expectedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + + XCTAssertEqual(actualFilterSet, expectedFilterDict) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefix) + } + + func testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async { + let encoder = JSONEncoder() + // On Disk Data Setup + let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")]) + let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) + let onDiskHashPrefix = Set(["faffa"]) + let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix)) + fileStore.write(data: filterSetData, to: Constants.filterSetFileName) + fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName) + + // Embedded Data Setup + embeddedDataProvider.embeddedRevision = 5 + let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 5, items: embeddedFilterSet) + let embeddedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 5) + XCTAssertEqual(actualHashPrefixRevision, 5) + } + + func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async { + // On Disk Data Setup + let onDiskFilterDict = FilterDictionary(revision: 6, items: [Filter(hash: "other", regex: "other")]) + let filterSetData = try! JSONEncoder().encode(onDiskFilterDict) + let onDiskHashPrefix = HashPrefixSet(revision: 6, items: ["faffa"]) + let hashPrefixData = try! JSONEncoder().encode(onDiskHashPrefix) + fileStore.write(data: filterSetData, to: Constants.filterSetFileName) + fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName) + + // Embedded Data Setup + embeddedDataProvider.embeddedRevision = 1 + let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, onDiskFilterDict) + XCTAssertEqual(actualHashPrefix, onDiskHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 6) + XCTAssertEqual(actualHashPrefixRevision, 6) + } + + func testWhenStoredDataIsMalformed_ThenEmbeddedDataIsLoaded() async { + // On Disk Data Setup + fileStore.write(data: "fake".utf8data, to: Constants.filterSetFileName) + fileStore.write(data: "fake".utf8data, to: Constants.hashPrefixesFileName) + + // Embedded Data Setup + embeddedDataProvider.embeddedRevision = 1 + let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 1, items: embeddedFilterSet) + let embeddedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) + } + + func testWriteAndLoadData() async { + // Get and write data + let expectedHashPrefixes = Set(["aabb"]) + let expectedFilterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) + let expectedRevision = 65 + + await dataManager.store(HashPrefixSet(revision: expectedRevision, items: expectedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: expectedRevision, items: expectedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 65) + XCTAssertEqual(actualHashPrefixRevision, 65) + + // Test reloading data + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 65) + XCTAssertEqual(reloadedHashPrefixRevision, 65) + } + + func testLazyLoadingDoesNotReturnStaleData() async { + clearDatasets() + + // Set up initial data + let initialFilterSet = Set([Filter(hash: "initial", regex: "initial")]) + let initialHashPrefixes = Set(["initialPrefix"]) + embeddedDataProvider.filterSet = initialFilterSet + embeddedDataProvider.hashPrefixes = initialHashPrefixes + + // Access the lazy-loaded properties to trigger loading + let loadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let loadedHashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + + // Validate loaded data matches initial data + XCTAssertEqual(loadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(loadedHashPrefixes.set, initialHashPrefixes) + + // Update in-memory data + let updatedFilterSet = Set([Filter(hash: "updated", regex: "updated")]) + let updatedHashPrefixes = Set(["updatedPrefix"]) + await dataManager.store(HashPrefixSet(revision: 1, items: updatedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: 1, items: updatedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: 1, items: updatedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, updatedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) + + // Test reloading data – embedded data should be returned as its revision is greater + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, initialHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 1) + XCTAssertEqual(reloadedHashPrefixRevision, 1) + } + +} + +class MockMaliciousSiteProtectionFileStore: MaliciousSiteProtection.FileStoring { + + private var data: [String: Data] = [:] + + func write(data: Data, to filename: String) -> Bool { + self.data[filename] = data + return true + } + + func read(from filename: String) -> Data? { + return data[filename] + } +} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift new file mode 100644 index 000000000..1e3e0df40 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift @@ -0,0 +1,62 @@ +// +// MaliciousSiteProtectionEmbeddedDataProviderTest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +import Foundation +import XCTest + +@testable import MaliciousSiteProtection + +class MaliciousSiteProtectionEmbeddedDataProviderTest: XCTestCase { + + struct TestEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + func revision(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + 0 + } + + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { + case .filterSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)FilterSet", withExtension: "json")! + case .hashPrefixSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)HashPrefixes", withExtension: "json")! + } + } + + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + switch dataType { + case .filterSet(let key): + switch key.threatKind { + case .phishing: + "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504" + } + case .hashPrefixSet(let key): + switch key.threatKind { + case .phishing: + "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f" + } + } + } + } + + func testDataProviderLoadsJSON() { + let dataProvider = TestEmbeddedDataProvider() + let expectedFilter = Filter(hash: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$") + XCTAssertTrue(dataProvider.loadDataSet(for: .filterSet(threatKind: .phishing)).contains(expectedFilter)) + XCTAssertTrue(dataProvider.loadDataSet(for: .hashPrefixes(threatKind: .phishing)).contains("012db806")) + } + +} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift similarity index 92% rename from Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift rename to Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift index ea0576369..8df462b3e 100644 --- a/Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift @@ -1,5 +1,5 @@ // -// PhishingDetectionURLTests.swift +// MaliciousSiteProtectionURLTests.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -18,9 +18,10 @@ import Foundation import XCTest -@testable import PhishingDetection -class PhishingDetectionURLTests: XCTestCase { +@testable import MaliciousSiteProtection + +class MaliciousSiteProtectionURLTests: XCTestCase { let testURLs = [ "http://www.example.com/security/badware/phishing.html#frags", diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift new file mode 100644 index 000000000..8d46d5cf7 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift @@ -0,0 +1,392 @@ +// +// MaliciousSiteProtectionUpdateManagerTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Clocks +import Common +import Foundation +import XCTest + +@testable import MaliciousSiteProtection + +class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { + + var updateManager: MaliciousSiteProtection.UpdateManager! + var dataManager: MockMaliciousSiteProtectionDataManager! + var apiClient: MaliciousSiteProtection.APIClient.Mockable! + var updateIntervalProvider: UpdateManager.UpdateIntervalProvider! + var clock: TestClock! + var willSleep: ((TimeInterval) -> Void)? + var updateTask: Task? + + override func setUp() async throws { + apiClient = MockMaliciousSiteProtectionAPIClient() + dataManager = MockMaliciousSiteProtectionDataManager() + clock = TestClock() + + let clockSleeper = Sleeper(clock: clock) + let reportingSleeper = Sleeper { + self.willSleep?($0) + try await clockSleeper.sleep(for: $0) + } + + updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager, sleeper: reportingSleeper, updateIntervalProvider: { self.updateIntervalProvider($0) }) + } + + override func tearDown() async throws { + updateManager = nil + dataManager = nil + apiClient = nil + updateIntervalProvider = nil + updateTask?.cancel() + } + + func testUpdateHashPrefixes() async { + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + let dataSet = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + XCTAssertEqual(dataSet, HashPrefixSet(revision: 1, items: [ + "aa00bb11", + "bb00cc11", + "cc00dd11", + "dd00ee11", + "a379a6f6" + ])) + } + + func testUpdateFilterSet() async { + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + let dataSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(dataSet, FilterDictionary(revision: 1, items: [ + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash2", regex: ".*test.*") + ])) + } + + func testRevision1AddsAndDeletesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash3", regex: ".*test.*") + ] + let expectedHashPrefixes: Set = [ + "aa00bb11", + "bb00cc11", + "a379a6f6", + "93e2435e" + ] + + // revision 0 -> 1 + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + // revision 1 -> 2 + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 2, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 2, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision2AddsAndDeletesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash4", regex: ".*test.*"), + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), + ] + let expectedHashPrefixes: Set = [ + "aa00bb11", + "a379a6f6", + "c0be0d0a6", + "dd00ee11", + "cc00dd11" + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 2, items: [ + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 2, items: [ + "aa00bb11", + "bb00cc11", + "cc00dd11", + "dd00ee11", + "a379a6f6" + ]), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision3AddsAndDeletesNothing() async { + let expectedFilterSet: Set = [] + let expectedHashPrefixes: Set = [] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 3, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 3, items: []), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision4AddsAndDeletesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash5", regex: ".*test.*") + ] + let expectedHashPrefixes: Set = [ + "a379a6f6", + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 4, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 4, items: []), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 5, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 5, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision5replacesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash6", regex: ".*test6.*") + ] + let expectedHashPrefixes: Set = [ + "aa55aa55" + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 5, items: [ + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash5", regex: ".*test.*") + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 5, items: [ + "a379a6f6", + "dd00ee11", + "cc00dd11", + "bb00cc11" + ]), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 6, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 6, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testWhenPeriodicUpdatesStart_dataSetsAreUpdated() async throws { + self.updateIntervalProvider = { _ in 1 } + + let eHashPrefixesUpdated = expectation(description: "Hash prefixes updated") + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + eHashPrefixesUpdated.fulfill() + } + let eFilterSetUpdated = expectation(description: "Filter set updated") + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + eFilterSetUpdated.fulfill() + } + + updateTask = updateManager.startPeriodicUpdates() + await Task.megaYield(count: 10) + + // expect initial update run instantly + await fulfillment(of: [eHashPrefixesUpdated, eFilterSetUpdated], timeout: 1) + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreEnabled_dataSetsAreUpdatedContinuously() async throws { + // Start periodic updates + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return 2 + case .hashPrefixSet: return 1 + } + } + + let hashPrefixUpdateExpectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let filterSetUpdateExpectations = [ + XCTestExpectation(description: "Filter set rev.1 update received"), + XCTestExpectation(description: "Filter set rev.2 update received"), + XCTestExpectation(description: "Filter set rev.3 update received"), + ] + let hashPrefixSleepExpectations = [ + XCTestExpectation(description: "HP Will Sleep 1"), + XCTestExpectation(description: "HP Will Sleep 2"), + XCTestExpectation(description: "HP Will Sleep 3"), + ] + let filterSetSleepExpectations = [ + XCTestExpectation(description: "FS Will Sleep 1"), + XCTestExpectation(description: "FS Will Sleep 2"), + XCTestExpectation(description: "FS Will Sleep 3"), + ] + + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + hashPrefixUpdateExpectations[data.revision - 1].fulfill() + } + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + filterSetUpdateExpectations[data.revision - 1].fulfill() + } + var hashPrefixSleepIndex = 0 + var filterSetSleepIndex = 0 + self.willSleep = { interval in + if interval == 1 { + hashPrefixSleepExpectations[safe: hashPrefixSleepIndex]?.fulfill() + hashPrefixSleepIndex += 1 + } else { + filterSetSleepExpectations[safe: filterSetSleepIndex]?.fulfill() + filterSetSleepIndex += 1 + } + } + + // expect initial hashPrefixes update run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [hashPrefixUpdateExpectations[0], hashPrefixSleepExpectations[0], filterSetUpdateExpectations[0], filterSetSleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [hashPrefixUpdateExpectations[1], hashPrefixSleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes and v.2 update for filterSet + await fulfillment(of: [hashPrefixUpdateExpectations[2], hashPrefixSleepExpectations[2], filterSetUpdateExpectations[1], filterSetSleepExpectations[1]], timeout: 1) // + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(2)) + // expect to receive v.3 update for filterSet and no update for hashPrefixes (no v.3 updates in the mock) + await fulfillment(of: [filterSetUpdateExpectations[2], filterSetSleepExpectations[2]], timeout: 1) // + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreDisabled_noDataSetsAreUpdated() async throws { + // Start periodic updates + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return nil // Set update interval to nil for FilterSet + case .hashPrefixSet: return 1 + } + } + + let expectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + expectations[data.revision - 1].fulfill() + } + // data for FilterSet should not be updated + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + XCTFail("Unexpected filter set update received: \(data)") + } + // synchronize Task threads to advance the Test Clock when the updated Task is sleeping, + // otherwise we‘ll eventually advance the clock before the sleep and get hung. + var sleepIndex = 0 + let sleepExpectations = [ + XCTestExpectation(description: "Will Sleep 1"), + XCTestExpectation(description: "Will Sleep 2"), + XCTestExpectation(description: "Will Sleep 3"), + ] + self.willSleep = { _ in + sleepExpectations[sleepIndex].fulfill() + sleepIndex += 1 + } + + // expect initial hashPrefixes update run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [expectations[0], sleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [expectations[1], sleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes + await fulfillment(of: [expectations[2], sleepExpectations[2]], timeout: 1) + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreCancelled_noFurtherUpdatesReceived() async throws { + // Start periodic updates + self.updateIntervalProvider = { _ in 1 } + updateTask = updateManager.startPeriodicUpdates() + + // Wait for the initial update + try await withTimeout(1) { [self] in + for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + } + + // Cancel the update task + updateTask!.cancel() + + // Reset expectations for further updates + let c = await dataManager.$store.dropFirst().sink { data in + XCTFail("Unexpected data update received: \(data)") + } + + // Advance the clock to check for further updates + await self.clock.advance(by: .seconds(2)) + await clock.run() + await Task.megaYield(count: 10) + + // Verify that the data sets have not been updated further + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(hashPrefixes.revision, 1) // Expecting revision to remain 1 + XCTAssertEqual(filterSet.revision, 1) // Expecting revision to remain 1 + + withExtendedLifetime(c) {} + } + +} diff --git a/Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift similarity index 80% rename from Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift rename to Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift index 7c736c7e3..1edbb98a2 100644 --- a/Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift @@ -1,5 +1,5 @@ // -// EventMappingMock.swift +// MockEventMapping.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -15,13 +15,14 @@ // See the License for the specific language governing permissions and // limitations under the License. // -import Foundation + import Common -import PhishingDetection +import Foundation +import MaliciousSiteProtection import PixelKit -public class MockEventMapping: EventMapping { - static var events: [PhishingDetectionEvents] = [] +public class MockEventMapping: EventMapping { + static var events: [MaliciousSiteProtection.Event] = [] static var clientSideHitParam: String? static var errorParam: Error? @@ -39,7 +40,7 @@ public class MockEventMapping: EventMapping { } } - override init(mapping: @escaping EventMapping.Mapping) { + override init(mapping: @escaping EventMapping.Mapping) { fatalError("Use init()") } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift new file mode 100644 index 000000000..4f2062edd --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -0,0 +1,103 @@ +// +// MockMaliciousSiteProtectionAPIClient.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import MaliciousSiteProtection + +class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mockable { + var updateHashPrefixesCalled: ((Int) -> Void)? + var updateFilterSetsCalled: ((Int) -> Void)? + + var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ + 0: .init(insert: [ + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash2", regex: ".*test.*") + ], delete: [], revision: 1, replace: false), + 1: .init(insert: [ + Filter(hash: "testhash3", regex: ".*test.*") + ], delete: [ + Filter(hash: "testhash1", regex: ".*example.*"), + ], revision: 2, replace: false), + 2: .init(insert: [ + Filter(hash: "testhash4", regex: ".*test.*") + ], delete: [ + Filter(hash: "testhash2", regex: ".*test.*"), + ], revision: 3, replace: false), + 4: .init(insert: [ + Filter(hash: "testhash5", regex: ".*test.*") + ], delete: [ + Filter(hash: "testhash3", regex: ".*test.*"), + ], revision: 5, replace: false), + 5: .init(insert: [ + Filter(hash: "testhash6", regex: ".*test6.*") + ], delete: [ + Filter(hash: "testhash3", regex: ".*test.*"), + ], revision: 6, replace: true), + ] + + private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [ + 0: .init(insert: [ + "aa00bb11", + "bb00cc11", + "cc00dd11", + "dd00ee11", + "a379a6f6" + ], delete: [], revision: 1, replace: false), + 1: .init(insert: ["93e2435e"], delete: [ + "cc00dd11", + "dd00ee11", + ], revision: 2, replace: false), + 2: .init(insert: ["c0be0d0a6"], delete: [ + "bb00cc11", + ], revision: 3, replace: false), + 4: .init(insert: ["a379a6f6"], delete: [ + "aa00bb11", + ], revision: 5, replace: false), + 5: .init(insert: ["aa55aa55"], delete: [ + "ffgghhzz", + ], revision: 6, replace: true), + ] + + func load(_ requestConfig: Request) async throws -> Request.Response where Request: APIClient.Request { + switch requestConfig.requestType { + case .hashPrefixSet(let configuration): + return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response + case .filterSet(let configuration): + return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response + case .matches(let configuration): + return _matches(forHashPrefix: configuration.hashPrefix) as! Request.Response + } + } + func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { + updateFilterSetsCalled?(revision) + return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) + } + + func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet { + updateHashPrefixesCalled?(revision) + return hashPrefixRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) + } + + func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches { + .init(matches: [ + Match(hostname: "example.com", url: "https://example.com/mal", regex: ".*", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), + Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil) + ]) + } + +} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift new file mode 100644 index 000000000..1a67ad329 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -0,0 +1,40 @@ +// +// MockMaliciousSiteProtectionDataManager.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Combine +import Foundation +@testable import MaliciousSiteProtection + +actor MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { + + @Published var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() + func publisher(for key: DataKey) -> AnyPublisher where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + $store.map { $0[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } + .removeDuplicates() + .eraseToAnyPublisher() + } + + public func dataSet(for key: DataKey) -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + return store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) + } + + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + store[key.dataType] = dataSet + } + +} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift new file mode 100644 index 000000000..10a3e2643 --- /dev/null +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -0,0 +1,81 @@ +// +// MockMaliciousSiteProtectionEmbeddedDataProvider.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation +@testable import MaliciousSiteProtection + +final class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + var embeddedRevision: Int = 65 + var loadHashPrefixesCalled: Bool = false + var loadFilterSetCalled: Bool = true + var hashPrefixes: Set = [] { + didSet { + hashPrefixesData = try! JSONEncoder().encode(hashPrefixes) + } + } + var hashPrefixesData: Data! + + var filterSet: Set = [] { + didSet { + filterSetData = try! JSONEncoder().encode(filterSet) + } + } + var filterSetData: Data! + + init() { + hashPrefixes = Set(["aabb"]) + filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) + } + + func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + embeddedRevision + } + + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { + case .filterSet: + self.loadFilterSetCalled = true + return URL(string: "filterSet")! + case .hashPrefixSet: + self.loadHashPrefixesCalled = true + return URL(string: "hashPrefixSet")! + } + } + + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + let url = url(for: dataType) + let data = try! data(withContentsOf: url) + let sha = data.sha256 + return sha + } + + func data(withContentsOf url: URL) throws -> Data { + let data: Data + switch url.absoluteString { + case "filterSet": + self.loadFilterSetCalled = true + return filterSetData + case "hashPrefixSet": + self.loadHashPrefixesCalled = true + return hashPrefixesData + default: + fatalError("Unexpected url \(url.absoluteString)") + } + } + +} diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift similarity index 59% rename from Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift rename to Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index d5ca12559..b49eac588 100644 --- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -1,5 +1,5 @@ // -// PhishingDetectionUpdateManagerMock.swift +// MockPhishingDetectionUpdateManager.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -17,27 +17,40 @@ // import Foundation -import PhishingDetection +@testable import MaliciousSiteProtection + +class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { -public class MockPhishingDetectionUpdateManager: PhishingDetectionUpdateManaging { var didUpdateFilterSet = false var didUpdateHashPrefixes = false + var startPeriodicUpdatesCalled = false var completionHandler: (() -> Void)? - public func updateFilterSet() async { + func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { + switch key.dataType { + case .filterSet: await updateFilterSet() + case .hashPrefixSet: await updateHashPrefixes() + } + } + + func updateFilterSet() async { didUpdateFilterSet = true checkCompletion() } - public func updateHashPrefixes() async { + func updateHashPrefixes() async { didUpdateHashPrefixes = true checkCompletion() } - private func checkCompletion() { + func checkCompletion() { if didUpdateFilterSet && didUpdateHashPrefixes { completionHandler?() } } + public func startPeriodicUpdates() -> Task { + startPeriodicUpdatesCalled = true + return Task {} + } } diff --git a/Tests/PhishingDetectionTests/Resources/filterSet.json b/Tests/MaliciousSiteProtectionTests/Resources/phishingFilterSet.json similarity index 100% rename from Tests/PhishingDetectionTests/Resources/filterSet.json rename to Tests/MaliciousSiteProtectionTests/Resources/phishingFilterSet.json diff --git a/Tests/PhishingDetectionTests/Resources/hashPrefixes.json b/Tests/MaliciousSiteProtectionTests/Resources/phishingHashPrefixes.json similarity index 100% rename from Tests/PhishingDetectionTests/Resources/hashPrefixes.json rename to Tests/MaliciousSiteProtectionTests/Resources/phishingHashPrefixes.json diff --git a/Tests/NavigationTests/Helpers/NavigationResponderMock.swift b/Tests/NavigationTests/Helpers/NavigationResponderMock.swift index d39a6ee44..fda1b2805 100644 --- a/Tests/NavigationTests/Helpers/NavigationResponderMock.swift +++ b/Tests/NavigationTests/Helpers/NavigationResponderMock.swift @@ -374,7 +374,6 @@ class NavigationResponderMock: NavigationResponder { var onDidTerminate: (@MainActor (WKProcessTerminationReason?) -> Void)? func webContentProcessDidTerminate(with reason: WKProcessTerminationReason?) { - let event = append(.didTerminate(reason)) onDidTerminate?(reason) } diff --git a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift index 59eeadebb..4ec1b8b59 100644 --- a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift +++ b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift @@ -41,21 +41,18 @@ final class APIRequestV2Tests: XCTestCase { cachePolicy: cachePolicy, responseConstraints: constraints) - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.url?.host(), url.host()) XCTAssertEqual(urlRequest.httpMethod, method.rawValue) let urlComponents = URLComponents(string: urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertEqual(urlRequest.allHTTPHeaderFields, headers.httpHeaders) XCTAssertEqual(urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.timeoutInterval, timeoutInterval) XCTAssertEqual(urlRequest.cachePolicy, cachePolicy) - XCTAssertEqual(apiRequest?.responseConstraints, constraints) + XCTAssertEqual(apiRequest.responseConstraints, constraints) } func testURLRequestGeneration() { @@ -75,16 +72,16 @@ final class APIRequestV2Tests: XCTestCase { timeoutInterval: timeoutInterval, cachePolicy: cachePolicy) - let urlComponents = URLComponents(string: apiRequest!.urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + let urlComponents = URLComponents(string: apiRequest.urlRequest.url!.absoluteString)! + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertNotNil(apiRequest) - XCTAssertEqual(apiRequest?.urlRequest.url?.absoluteString, "https://www.example.com?key=value") - XCTAssertEqual(apiRequest?.urlRequest.httpMethod, method.rawValue) - XCTAssertEqual(apiRequest?.urlRequest.allHTTPHeaderFields, headers.httpHeaders) - XCTAssertEqual(apiRequest?.urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.urlRequest.timeoutInterval, timeoutInterval) - XCTAssertEqual(apiRequest?.urlRequest.cachePolicy, cachePolicy) + XCTAssertEqual(apiRequest.urlRequest.url?.absoluteString, "https://www.example.com?key=value") + XCTAssertEqual(apiRequest.urlRequest.httpMethod, method.rawValue) + XCTAssertEqual(apiRequest.urlRequest.allHTTPHeaderFields, headers.httpHeaders) + XCTAssertEqual(apiRequest.urlRequest.httpBody, body) + XCTAssertEqual(apiRequest.urlRequest.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.urlRequest.cachePolicy, cachePolicy) } func testDefaultValues() { @@ -92,16 +89,13 @@ final class APIRequestV2Tests: XCTestCase { let apiRequest = APIRequestV2(url: url) let headers = APIRequestV2.HeadersV2() - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.httpMethod, HTTPRequestMethod.get.rawValue) XCTAssertEqual(urlRequest.timeoutInterval, 60.0) XCTAssertEqual(headers.httpHeaders, urlRequest.allHTTPHeaderFields) XCTAssertNil(urlRequest.httpBody) XCTAssertEqual(urlRequest.cachePolicy.rawValue, 0) - XCTAssertNil(apiRequest?.responseConstraints) + XCTAssertNil(apiRequest.responseConstraints) } func testAllowedQueryReservedCharacters() { @@ -112,9 +106,10 @@ final class APIRequestV2Tests: XCTestCase { queryItems: queryItems, allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) - let urlString = apiRequest!.urlRequest.url!.absoluteString - XCTAssertTrue(urlString == "https://www.example.com?k%2523e,y=val%2523ue") + let urlString = apiRequest.urlRequest.url!.absoluteString + XCTAssertEqual(urlString, "https://www.example.com?k%23e,y=val%23ue") + let urlComponents = URLComponents(string: urlString)! - XCTAssertTrue(urlComponents.queryItems?.count == 1) + XCTAssertEqual(urlComponents.queryItems?.count, 1) } } diff --git a/Tests/NetworkingTests/v2/APIServiceTests.swift b/Tests/NetworkingTests/v2/APIServiceTests.swift index 9cae44323..730d6afbb 100644 --- a/Tests/NetworkingTests/v2/APIServiceTests.swift +++ b/Tests/NetworkingTests/v2/APIServiceTests.swift @@ -41,7 +41,7 @@ final class APIServiceTests: XCTestCase { cachePolicy: .reloadIgnoringLocalAndRemoteCacheData, responseConstraints: [APIResponseConstraints.allowHTTPNotModified, APIResponseConstraints.requireETagHeader], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService() let response = try await apiService.fetch(request: request) let responseHTML: String = try response.decodeBody() @@ -50,7 +50,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallJSON() async throws { // func testRealCallJSON() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -63,7 +63,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallString() async throws { // func testRealCallString() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -75,17 +75,16 @@ final class APIServiceTests: XCTestCase { "qName2": "qValue2"] MockURLProtocol.requestHandler = { request in let urlComponents = URLComponents(string: request.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(qItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(qItems.map { URLQueryItem(name: $0.key, value: $0.value) })) return (HTTPURLResponse.ok, nil) } - let request = APIRequestV2(url: HTTPURLResponse.testUrl, - queryItems: qItems)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems) let apiService = DefaultAPIService(urlSession: mockURLSession) _ = try await apiService.fetch(request: request) } func testURLRequestError() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) enum TestError: Error { case anError @@ -111,7 +110,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementAllowHTTPNotModifiedSuccess() async throws { let requirements = [APIResponseConstraints.allowHTTPNotModified ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -122,7 +121,7 @@ final class APIServiceTests: XCTestCase { } func testResponseRequirementAllowHTTPNotModifiedFailure() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -147,7 +146,7 @@ final class APIServiceTests: XCTestCase { let requirements: [APIResponseConstraints] = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } // HTTPURLResponse.ok contains etag let apiService = DefaultAPIService(urlSession: mockURLSession) @@ -158,7 +157,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireETagHeaderFailure() async throws { let requirements = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okNoEtag, nil) } @@ -181,7 +180,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentSuccess() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okUserAgent, nil) @@ -194,7 +193,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentFailure() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } diff --git a/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift b/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift index ed7927f4f..942543505 100644 --- a/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift +++ b/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift @@ -147,11 +147,11 @@ class CapturingOnboardingNavigationDelegate: OnboardingNavigationDelegate { var suggestedSearchQuery: String? var urlToNavigateTo: URL? - func searchFor(_ query: String) { + func searchFromOnboarding(for query: String) { suggestedSearchQuery = query } - func navigateTo(url: URL) { + func navigateFromOnboarding(to url: URL) { urlToNavigateTo = url } } diff --git a/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift b/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift deleted file mode 100644 index 8d907efbd..000000000 --- a/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift +++ /dev/null @@ -1,57 +0,0 @@ -// -// BackgroundActivitySchedulerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -import Foundation -import XCTest -@testable import PhishingDetection - -class BackgroundActivitySchedulerTests: XCTestCase { - var scheduler: BackgroundActivityScheduler! - var activityWasRun = false - - override func tearDown() { - scheduler = nil - super.tearDown() - } - - func testStart() async throws { - let expectation = self.expectation(description: "Activity should run") - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - if !self.activityWasRun { - self.activityWasRun = true - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 2) - XCTAssertTrue(activityWasRun) - } - - func testRepeats() async throws { - let expectation = self.expectation(description: "Activity should repeat") - var runCount = 0 - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - runCount += 1 - if runCount == 2 { - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 3) - XCTAssertEqual(runCount, 2) - } -} diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift deleted file mode 100644 index 9f39598b2..000000000 --- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift +++ /dev/null @@ -1,84 +0,0 @@ -// -// PhishingDetectionClientMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import PhishingDetection - -public class MockPhishingDetectionClient: PhishingDetectionClientProtocol { - public var updateHashPrefixesWasCalled: Bool = false - public var updateFilterSetsWasCalled: Bool = false - - private var filterRevisions: [Int: FilterSetResponse] = [ - 0: FilterSetResponse(insert: [ - Filter(hashValue: "testhash1", regex: ".*example.*"), - Filter(hashValue: "testhash2", regex: ".*test.*") - ], delete: [], revision: 0, replace: true), - 1: FilterSetResponse(insert: [ - Filter(hashValue: "testhash3", regex: ".*test.*") - ], delete: [ - Filter(hashValue: "testhash1", regex: ".*example.*"), - ], revision: 1, replace: false), - 2: FilterSetResponse(insert: [ - Filter(hashValue: "testhash4", regex: ".*test.*") - ], delete: [ - Filter(hashValue: "testhash2", regex: ".*test.*"), - ], revision: 2, replace: false), - 4: FilterSetResponse(insert: [ - Filter(hashValue: "testhash5", regex: ".*test.*") - ], delete: [ - Filter(hashValue: "testhash3", regex: ".*test.*"), - ], revision: 4, replace: false) - ] - - private var hashPrefixRevisions: [Int: HashPrefixResponse] = [ - 0: HashPrefixResponse(insert: [ - "aa00bb11", - "bb00cc11", - "cc00dd11", - "dd00ee11", - "a379a6f6" - ], delete: [], revision: 0, replace: true), - 1: HashPrefixResponse(insert: ["93e2435e"], delete: [ - "cc00dd11", - "dd00ee11", - ], revision: 1, replace: false), - 2: HashPrefixResponse(insert: ["c0be0d0a6"], delete: [ - "bb00cc11", - ], revision: 2, replace: false), - 4: HashPrefixResponse(insert: ["a379a6f6"], delete: [ - "aa00bb11", - ], revision: 4, replace: false) - ] - - public func getFilterSet(revision: Int) async -> FilterSetResponse { - updateFilterSetsWasCalled = true - return filterRevisions[revision] ?? FilterSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - - public func getHashPrefixes(revision: Int) async -> HashPrefixResponse { - updateHashPrefixesWasCalled = true - return hashPrefixRevisions[revision] ?? HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false) - } - - public func getMatches(hashPrefix: String) async -> [Match] { - return [ - Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947"), - Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11") - ] - } -} diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift deleted file mode 100644 index 79d4d5d6b..000000000 --- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift +++ /dev/null @@ -1,47 +0,0 @@ -// -// PhishingDetectionDataProviderMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import PhishingDetection - -public class MockPhishingDetectionDataProvider: PhishingDetectionDataProviding { - public var embeddedRevision: Int = 65 - var loadHashPrefixesCalled: Bool = false - var loadFilterSetCalled: Bool = true - var hashPrefixes: Set = ["aabb"] - var filterSet: Set = [Filter(hashValue: "dummyhash", regex: "dummyregex")] - - public func shouldReturnFilterSet(set: Set) { - self.filterSet = set - } - - public func shouldReturnHashPrefixes(set: Set) { - self.hashPrefixes = set - } - - public func loadEmbeddedFilterSet() -> Set { - self.loadHashPrefixesCalled = true - return self.filterSet - } - - public func loadEmbeddedHashPrefixes() -> Set { - self.loadFilterSetCalled = true - return self.hashPrefixes - } - -} diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift deleted file mode 100644 index 54521419c..000000000 --- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift +++ /dev/null @@ -1,44 +0,0 @@ -// -// PhishingDetectionDataStoreMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import PhishingDetection - -public class MockPhishingDetectionDataStore: PhishingDetectionDataSaving { - public var filterSet: Set - public var hashPrefixes: Set - public var currentRevision: Int - - public init() { - filterSet = Set() - hashPrefixes = Set() - currentRevision = 0 - } - - public func saveFilterSet(set: Set) { - filterSet = set - } - - public func saveHashPrefixes(set: Set) { - hashPrefixes = set - } - - public func saveRevision(_ revision: Int) { - currentRevision = revision - } -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift deleted file mode 100644 index 6826c86d6..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift +++ /dev/null @@ -1,125 +0,0 @@ -// -// PhishingDetectionClientTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -import Foundation -import XCTest -@testable import PhishingDetection - -final class PhishingDetectionAPIClientTests: XCTestCase { - - var mockSession: MockURLSession! - var client: PhishingDetectionAPIClient! - - override func setUp() { - super.setUp() - mockSession = MockURLSession() - client = PhishingDetectionAPIClient(environment: .staging, session: mockSession) - } - - override func tearDown() { - mockSession = nil - client = nil - super.tearDown() - } - - func testGetFilterSetSuccess() async { - // Given - let insertFilter = Filter(hashValue: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".") - let deleteFilter = Filter(hashValue: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?") - let expectedResponse = FilterSetResponse(insert: [insertFilter], delete: [deleteFilter], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.filterSetURL, statusCode: 200, httpVersion: nil, headerFields: nil) - - // When - let response = await client.getFilterSet(revision: 1) - - // Then - XCTAssertEqual(response, expectedResponse) - } - - func testGetHashPrefixesSuccess() async { - // Given - let expectedResponse = HashPrefixResponse(insert: ["abc"], delete: ["def"], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.hashPrefixURL, statusCode: 200, httpVersion: nil, headerFields: nil) - - // When - let response = await client.getHashPrefixes(revision: 1) - - // Then - XCTAssertEqual(response, expectedResponse) - } - - func testGetMatchesSuccess() async { - // Given - let expectedResponse = MatchResponse(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947")]) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.matchesURL, statusCode: 200, httpVersion: nil, headerFields: nil) - - // When - let response = await client.getMatches(hashPrefix: "abc") - - // Then - XCTAssertEqual(response, expectedResponse.matches) - } - - func testGetFilterSetInvalidURL() async { - // Given - let invalidRevision = -1 - - // When - let response = await client.getFilterSet(revision: invalidRevision) - - // Then - XCTAssertEqual(response, FilterSetResponse(insert: [], delete: [], revision: invalidRevision, replace: false)) - } - - func testGetHashPrefixesInvalidURL() async { - // Given - let invalidRevision = -1 - - // When - let response = await client.getHashPrefixes(revision: invalidRevision) - - // Then - XCTAssertEqual(response, HashPrefixResponse(insert: [], delete: [], revision: invalidRevision, replace: false)) - } - - func testGetMatchesInvalidURL() async { - // Given - let invalidHashPrefix = "" - - // When - let response = await client.getMatches(hashPrefix: invalidHashPrefix) - - // Then - XCTAssertTrue(response.isEmpty) - } -} - -class MockURLSession: URLSessionProtocol { - var data: Data? - var response: URLResponse? - var error: Error? - - func data(for request: URLRequest) async throws -> (Data, URLResponse) { - if let error = error { - throw error - } - return (data ?? Data(), response ?? URLResponse()) - } -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift deleted file mode 100644 index 583f94789..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift +++ /dev/null @@ -1,48 +0,0 @@ -// -// PhishingDetectionDataActivitiesTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import XCTest -@testable import PhishingDetection - -class PhishingDetectionDataActivitiesTests: XCTestCase { - var mockUpdateManager: MockPhishingDetectionUpdateManager! - var activities: PhishingDetectionDataActivities! - - override func setUp() { - super.setUp() - mockUpdateManager = MockPhishingDetectionUpdateManager() - activities = PhishingDetectionDataActivities(hashPrefixInterval: 1, filterSetInterval: 1, phishingDetectionDataProvider: MockPhishingDetectionDataProvider(), updateManager: mockUpdateManager) - } - - func testUpdateHashPrefixesAndFilterSetRuns() async { - let expectation = XCTestExpectation(description: "updateHashPrefixes and updateFilterSet completes") - - mockUpdateManager.completionHandler = { - expectation.fulfill() - } - - activities.start() - - await fulfillment(of: [expectation], timeout: 10.0) - - XCTAssertTrue(mockUpdateManager.didUpdateHashPrefixes) - XCTAssertTrue(mockUpdateManager.didUpdateFilterSet) - - } -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift deleted file mode 100644 index 547f2dce8..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift +++ /dev/null @@ -1,52 +0,0 @@ -// -// PhishingDetectionDataProviderTest.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -import Foundation -import XCTest -@testable import PhishingDetection - -class PhishingDetectionDataProviderTest: XCTestCase { - var filterSetURL: URL! - var hashPrefixURL: URL! - var dataProvider: PhishingDetectionDataProvider! - - override func setUp() { - super.setUp() - filterSetURL = Bundle.module.url(forResource: "filterSet", withExtension: "json")! - hashPrefixURL = Bundle.module.url(forResource: "hashPrefixes", withExtension: "json")! - } - - override func tearDown() { - filterSetURL = nil - hashPrefixURL = nil - dataProvider = nil - super.tearDown() - } - - func testDataProviderLoadsJSON() { - dataProvider = PhishingDetectionDataProvider(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f") - let expectedFilter = Filter(hashValue: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$") - XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().contains(expectedFilter)) - XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().contains("012db806")) - } - - func testReturnsNoneWhenSHAMismatch() { - dataProvider = PhishingDetectionDataProvider(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "xx0", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "00x") - XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().isEmpty) - XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().isEmpty) - } -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift deleted file mode 100644 index 79e9fb500..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift +++ /dev/null @@ -1,197 +0,0 @@ -// -// PhishingDetectionDataStoreTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation - -import XCTest -@testable import PhishingDetection - -class PhishingDetectionDataStoreTests: XCTestCase { - var mockDataProvider: MockPhishingDetectionDataProvider! - let datasetFiles: [String] = ["hashPrefixes.json", "filterSet.json", "revision.txt"] - var dataStore: PhishingDetectionDataStore! - var fileStorageManager: FileStorageManager! - - override func setUp() { - super.setUp() - mockDataProvider = MockPhishingDetectionDataProvider() - fileStorageManager = MockPhishingFileStorageManager() - dataStore = PhishingDetectionDataStore(dataProvider: mockDataProvider, fileStorageManager: fileStorageManager) - } - - override func tearDown() { - mockDataProvider = nil - dataStore = nil - super.tearDown() - } - - func clearDatasets() { - for fileName in datasetFiles { - let emptyData = Data() - fileStorageManager.write(data: emptyData, to: fileName) - } - } - - func testWhenNoDataSavedThenProviderDataReturned() async { - clearDatasets() - let expectedFilerSet = Set([Filter(hashValue: "some", regex: "some")]) - let expectedHashPrefix = Set(["sassa"]) - mockDataProvider.shouldReturnFilterSet(set: expectedFilerSet) - mockDataProvider.shouldReturnHashPrefixes(set: expectedHashPrefix) - - let actualFilterSet = dataStore.filterSet - let actualHashPrefix = dataStore.hashPrefixes - - XCTAssertEqual(actualFilterSet, expectedFilerSet) - XCTAssertEqual(actualHashPrefix, expectedHashPrefix) - } - - func testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async { - let encoder = JSONEncoder() - // On Disk Data Setup - fileStorageManager.write(data: "1".utf8data, to: "revision.txt") - let onDiskFilterSet = Set([Filter(hashValue: "other", regex: "other")]) - let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) - let onDiskHashPrefix = Set(["faffa"]) - let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix)) - fileStorageManager.write(data: filterSetData, to: "filterSet.json") - fileStorageManager.write(data: hashPrefixData, to: "hashPrefixes.json") - - // Embedded Data Setup - mockDataProvider.embeddedRevision = 5 - let embeddedFilterSet = Set([Filter(hashValue: "some", regex: "some")]) - let embeddedHashPrefix = Set(["sassa"]) - mockDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - mockDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) - - let actualRevision = dataStore.currentRevision - let actualFilterSet = dataStore.filterSet - let actualHashPrefix = dataStore.hashPrefixes - - XCTAssertEqual(actualFilterSet, embeddedFilterSet) - XCTAssertEqual(actualHashPrefix, embeddedHashPrefix) - XCTAssertEqual(actualRevision, 5) - } - - func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async { - let encoder = JSONEncoder() - // On Disk Data Setup - fileStorageManager.write(data: "6".utf8data, to: "revision.txt") - let onDiskFilterSet = Set([Filter(hashValue: "other", regex: "other")]) - let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) - let onDiskHashPrefix = Set(["faffa"]) - let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix)) - fileStorageManager.write(data: filterSetData, to: "filterSet.json") - fileStorageManager.write(data: hashPrefixData, to: "hashPrefixes.json") - - // Embedded Data Setup - mockDataProvider.embeddedRevision = 1 - let embeddedFilterSet = Set([Filter(hashValue: "some", regex: "some")]) - let embeddedHashPrefix = Set(["sassa"]) - mockDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - mockDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) - - let actualRevision = dataStore.currentRevision - let actualFilterSet = dataStore.filterSet - let actualHashPrefix = dataStore.hashPrefixes - - XCTAssertEqual(actualFilterSet, onDiskFilterSet) - XCTAssertEqual(actualHashPrefix, onDiskHashPrefix) - XCTAssertEqual(actualRevision, 6) - } - - func testWriteAndLoadData() async { - // Get and write data - let expectedHashPrefixes = Set(["aabb"]) - let expectedFilterSet = Set([Filter(hashValue: "dummyhash", regex: "dummyregex")]) - let expectedRevision = 65 - - dataStore.saveHashPrefixes(set: expectedHashPrefixes) - dataStore.saveFilterSet(set: expectedFilterSet) - dataStore.saveRevision(expectedRevision) - - XCTAssertEqual(dataStore.filterSet, expectedFilterSet) - XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes) - XCTAssertEqual(dataStore.currentRevision, expectedRevision) - - // Test decode JSON data to expected types - let storedHashPrefixesData = fileStorageManager.read(from: "hashPrefixes.json") - let storedFilterSetData = fileStorageManager.read(from: "filterSet.json") - let storedRevisionData = fileStorageManager.read(from: "revision.txt") - - let decoder = JSONDecoder() - if let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!), - let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!), - let storedRevisionString = String(data: storedRevisionData!, encoding: .utf8), - let storedRevision = Int(storedRevisionString.trimmingCharacters(in: .whitespacesAndNewlines)) { - - XCTAssertEqual(storedFilterSet, expectedFilterSet) - XCTAssertEqual(storedHashPrefixes, expectedHashPrefixes) - XCTAssertEqual(storedRevision, expectedRevision) - } else { - XCTFail("Failed to decode stored PhishingDetection data") - } - } - - func testLazyLoadingDoesNotReturnStaleData() async { - clearDatasets() - - // Set up initial data - let initialFilterSet = Set([Filter(hashValue: "initial", regex: "initial")]) - let initialHashPrefixes = Set(["initialPrefix"]) - mockDataProvider.shouldReturnFilterSet(set: initialFilterSet) - mockDataProvider.shouldReturnHashPrefixes(set: initialHashPrefixes) - - // Access the lazy-loaded properties to trigger loading - let loadedFilterSet = dataStore.filterSet - let loadedHashPrefixes = dataStore.hashPrefixes - - // Validate loaded data matches initial data - XCTAssertEqual(loadedFilterSet, initialFilterSet) - XCTAssertEqual(loadedHashPrefixes, initialHashPrefixes) - - // Update in-memory data - let updatedFilterSet = Set([Filter(hashValue: "updated", regex: "updated")]) - let updatedHashPrefixes = Set(["updatedPrefix"]) - dataStore.saveFilterSet(set: updatedFilterSet) - dataStore.saveHashPrefixes(set: updatedHashPrefixes) - - // Access lazy-loaded properties again - let reloadedFilterSet = dataStore.filterSet - let reloadedHashPrefixes = dataStore.hashPrefixes - - // Validate reloaded data matches updated data - XCTAssertEqual(reloadedFilterSet, updatedFilterSet) - XCTAssertEqual(reloadedHashPrefixes, updatedHashPrefixes) - - // Validate on-disk data is also updated - let storedFilterSetData = fileStorageManager.read(from: "filterSet.json") - let storedHashPrefixesData = fileStorageManager.read(from: "hashPrefixes.json") - - let decoder = JSONDecoder() - if let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!), - let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!) { - - XCTAssertEqual(storedFilterSet, updatedFilterSet) - XCTAssertEqual(storedHashPrefixes, updatedHashPrefixes) - } else { - XCTFail("Failed to decode stored PhishingDetection data after update") - } - } - -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift deleted file mode 100644 index 6fec6c134..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift +++ /dev/null @@ -1,155 +0,0 @@ -// -// PhishingDetectionUpdateManagerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation - -import XCTest -@testable import PhishingDetection - -class PhishingDetectionUpdateManagerTests: XCTestCase { - var updateManager: PhishingDetectionUpdateManager! - var dataStore: PhishingDetectionDataSaving! - var mockClient: MockPhishingDetectionClient! - - override func setUp() async throws { - try await super.setUp() - mockClient = MockPhishingDetectionClient() - dataStore = MockPhishingDetectionDataStore() - updateManager = PhishingDetectionUpdateManager(client: mockClient, dataStore: dataStore) - dataStore.saveRevision(0) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() - } - - override func tearDown() { - updateManager = nil - dataStore = nil - mockClient = nil - super.tearDown() - } - - func testUpdateHashPrefixes() async { - await updateManager.updateHashPrefixes() - XCTAssertFalse(dataStore.hashPrefixes.isEmpty, "Hash prefixes should not be empty after update.") - XCTAssertEqual(dataStore.hashPrefixes, [ - "aa00bb11", - "bb00cc11", - "cc00dd11", - "dd00ee11", - "a379a6f6" - ]) - } - - func testUpdateFilterSet() async { - await updateManager.updateFilterSet() - XCTAssertEqual(dataStore.filterSet, [ - Filter(hashValue: "testhash1", regex: ".*example.*"), - Filter(hashValue: "testhash2", regex: ".*test.*") - ]) - } - - func testRevision1AddsAndDeletesData() async { - let expectedFilterSet: Set = [ - Filter(hashValue: "testhash2", regex: ".*test.*"), - Filter(hashValue: "testhash3", regex: ".*test.*") - ] - let expectedHashPrefixes: Set = [ - "aa00bb11", - "bb00cc11", - "a379a6f6", - "93e2435e" - ] - - // Save revision and update the filter set and hash prefixes - dataStore.saveRevision(1) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() - - XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") - } - - func testRevision2AddsAndDeletesData() async { - let expectedFilterSet: Set = [ - Filter(hashValue: "testhash4", regex: ".*test.*"), - Filter(hashValue: "testhash1", regex: ".*example.*") - ] - let expectedHashPrefixes: Set = [ - "aa00bb11", - "a379a6f6", - "c0be0d0a6", - "dd00ee11", - "cc00dd11" - ] - - // Save revision and update the filter set and hash prefixes - dataStore.saveRevision(2) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() - - XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") - } - - func testRevision3AddsAndDeletesNothing() async { - let expectedFilterSet = dataStore.filterSet - let expectedHashPrefixes = dataStore.hashPrefixes - - // Save revision and update the filter set and hash prefixes - dataStore.saveRevision(3) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() - - XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") - } - - func testRevision4AddsAndDeletesData() async { - let expectedFilterSet: Set = [ - Filter(hashValue: "testhash2", regex: ".*test.*"), - Filter(hashValue: "testhash1", regex: ".*example.*"), - Filter(hashValue: "testhash5", regex: ".*test.*") - ] - let expectedHashPrefixes: Set = [ - "a379a6f6", - "dd00ee11", - "cc00dd11", - "bb00cc11" - ] - - // Save revision and update the filter set and hash prefixes - dataStore.saveRevision(4) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() - - XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") - } -} - -class MockPhishingFileStorageManager: FileStorageManager { - private var data: [String: Data] = [:] - - func write(data: Data, to filename: String) { - self.data[filename] = data - } - - func read(from filename: String) -> Data? { - return data[filename] - } -} diff --git a/Tests/PhishingDetectionTests/PhishingDetectorTests.swift b/Tests/PhishingDetectionTests/PhishingDetectorTests.swift deleted file mode 100644 index d2ef4a02e..000000000 --- a/Tests/PhishingDetectionTests/PhishingDetectorTests.swift +++ /dev/null @@ -1,104 +0,0 @@ -// -// PhishingDetectorTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -import Foundation -import XCTest -@testable import PhishingDetection - -class IsMaliciousTests: XCTestCase { - - private var mockAPIClient: MockPhishingDetectionClient! - private var mockDataStore: MockPhishingDetectionDataStore! - private var mockEventMapping: MockEventMapping! - private var detector: PhishingDetector! - - override func setUp() { - super.setUp() - mockAPIClient = MockPhishingDetectionClient() - mockDataStore = MockPhishingDetectionDataStore() - mockEventMapping = MockEventMapping() - detector = PhishingDetector(apiClient: mockAPIClient, dataStore: mockDataStore, eventMapping: mockEventMapping) - } - - override func tearDown() { - mockAPIClient = nil - mockDataStore = nil - mockEventMapping = nil - detector = nil - super.tearDown() - } - - func testIsMaliciousWithLocalFilterHit() async { - let filter = Filter(hashValue: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*") - mockDataStore.filterSet = Set([filter]) - mockDataStore.hashPrefixes = Set(["255a8a79"]) - - let url = URL(string: "https://malicious.com/")! - - let result = await detector.isMalicious(url: url) - - XCTAssertTrue(result) - } - - func testIsMaliciousWithApiMatch() async { - mockDataStore.filterSet = Set() - mockDataStore.hashPrefixes = ["a379a6f6"] - - let url = URL(string: "https://example.com/mal")! - - let result = await detector.isMalicious(url: url) - - XCTAssertTrue(result) - } - - func testIsMaliciousWithHashPrefixMatch() async { - let filter = Filter(hashValue: "notamatch", regex: ".*malicious.*") - mockDataStore.filterSet = [filter] - mockDataStore.hashPrefixes = ["4c64eb24"] // matches safe.com - - let url = URL(string: "https://safe.com")! - - let result = await detector.isMalicious(url: url) - - XCTAssertFalse(result) - } - - func testIsMaliciousWithFullHashMatch() async { - // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b - let filter = Filter(hashValue: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI") - mockDataStore.filterSet = [filter] - mockDataStore.hashPrefixes = ["4c64eb24"] - - let url = URL(string: "https://safe.com")! - - let result = await detector.isMalicious(url: url) - - XCTAssertFalse(result) - } - - func testIsMaliciousWithNoHashPrefixMatch() async { - let filter = Filter(hashValue: "testHash", regex: ".*malicious.*") - mockDataStore.filterSet = [filter] - mockDataStore.hashPrefixes = ["testPrefix"] - - let url = URL(string: "https://safe.com")! - - let result = await detector.isMalicious(url: url) - - XCTAssertFalse(result) - } -} diff --git a/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift b/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift index 4c03464fa..867e7b888 100644 --- a/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift +++ b/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift @@ -260,13 +260,13 @@ final class PrivacyDashboardControllerTests: XCTestCase { func testWhenIsPhishingSetThenJavaScriptEvaluatedWithCorrectString() { let expectation = XCTestExpectation() - let privacyInfo = PrivacyInfo(url: URL(string: "someurl.com")!, parentEntity: nil, protectionStatus: .init(unprotectedTemporary: false, enabledFeatures: [], allowlisted: true, denylisted: true), isPhishing: false) + let privacyInfo = PrivacyInfo(url: URL(string: "someurl.com")!, parentEntity: nil, protectionStatus: .init(unprotectedTemporary: false, enabledFeatures: [], allowlisted: true, denylisted: true), malicousSiteThreatKind: .none) makePrivacyDashboardController(entryPoint: .dashboard, privacyInfo: privacyInfo) let config = WKWebViewConfiguration() let mockWebView = MockWebView(frame: .zero, configuration: config, expectation: expectation) privacyDashboardController.webView = mockWebView - privacyDashboardController.privacyInfo!.isPhishing = true + privacyDashboardController.privacyInfo!.malicousSiteThreatKind = .phishing wait(for: [expectation], timeout: 100) XCTAssertEqual(mockWebView.capturedJavaScriptString, "window.onChangePhishingStatus({\"phishingStatus\":true})") diff --git a/Tests/PrivacyStatsTests/CurrentPackTests.swift b/Tests/PrivacyStatsTests/CurrentPackTests.swift new file mode 100644 index 000000000..f22ed322d --- /dev/null +++ b/Tests/PrivacyStatsTests/CurrentPackTests.swift @@ -0,0 +1,121 @@ +// +// CurrentPackTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Combine +import XCTest +@testable import PrivacyStats + +final class CurrentPackTests: XCTestCase { + var currentPack: CurrentPack! + + override func setUp() async throws { + currentPack = CurrentPack(pack: .init(timestamp: Date.currentPrivacyStatsPackTimestamp), commitDebounce: 10_000_000) + } + + func testThatRecordBlockedTrackerUpdatesThePack() async { + await currentPack.recordBlockedTracker("A") + let companyA = await currentPack.pack.trackers["A"] + XCTAssertEqual(companyA, 1) + } + + func testThatRecordBlockedTrackerTriggersCommitChangesEvent() async throws { + let packs = try await waitForCommitChangesEvents(for: 100_000_000) { + await currentPack.recordBlockedTracker("A") + } + + let companyA = await currentPack.pack.trackers["A"] + XCTAssertEqual(companyA, 1) + XCTAssertEqual(packs.first?.trackers["A"], 1) + } + + func testThatMultipleCallsToRecordBlockedTrackerOnlyTriggerOneCommitChangesEvent() async throws { + let packs = try await waitForCommitChangesEvents(for: 1000_000_000) { + await currentPack.recordBlockedTracker("A") + await currentPack.recordBlockedTracker("A") + await currentPack.recordBlockedTracker("A") + await currentPack.recordBlockedTracker("A") + await currentPack.recordBlockedTracker("A") + } + + XCTAssertEqual(packs.count, 1) + XCTAssertEqual(packs.first?.trackers["A"], 5) + } + + func testThatRecordBlockedTrackerCalledConcurrentlyForTheSameCompanyStoresAllCalls() async { + await withTaskGroup(of: Void.self) { group in + (0..<1000).forEach { _ in + group.addTask { + await self.currentPack.recordBlockedTracker("A") + } + } + } + let companyA = await currentPack.pack.trackers["A"] + XCTAssertEqual(companyA, 1000) + } + + func testWhenCurrentPackIsOldThenRecordBlockedTrackerSendsCommitEventAndCreatesNewPack() async throws { + let oldTimestamp = Date.currentPrivacyStatsPackTimestamp.daysAgo(1) + let pack = PrivacyStatsPack( + timestamp: oldTimestamp, + trackers: ["A": 100, "B": 50, "C": 400] + ) + currentPack = CurrentPack(pack: pack, commitDebounce: 10_000_000) + + let packs = try await waitForCommitChangesEvents(for: 100_000_000) { + await currentPack.recordBlockedTracker("A") + } + + XCTAssertEqual(packs.count, 2) + let oldPack = try XCTUnwrap(packs.first) + XCTAssertEqual(oldPack, pack) + let newPack = try XCTUnwrap(packs.last) + XCTAssertEqual(newPack, PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp, trackers: ["A": 1])) + } + + func testThatResetPackClearsAllRecordedTrackersAndSetsCurrentTimestamp() async { + let oldTimestamp = Date.currentPrivacyStatsPackTimestamp.daysAgo(1) + let pack = PrivacyStatsPack( + timestamp: oldTimestamp, + trackers: ["A": 100, "B": 50, "C": 400] + ) + currentPack = CurrentPack(pack: pack, commitDebounce: 10_000_000) + + await currentPack.resetPack() + + let packAfterReset = await currentPack.pack + XCTAssertEqual(packAfterReset, PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp, trackers: [:])) + } + + // MARK: - Helpers + + /** + * Sets up Combine subscription, then calls the provided block and then waits + * for the specific time before cancelling the subscription. + * Returns an array of values passed in the published events. + */ + func waitForCommitChangesEvents(for nanoseconds: UInt64, _ block: () async -> Void) async throws -> [PrivacyStatsPack] { + var packs: [PrivacyStatsPack] = [] + let cancellable = currentPack.commitChangesPublisher.sink { packs.append($0) } + + await block() + + try await Task.sleep(nanoseconds: nanoseconds) + cancellable.cancel() + return packs + } +} diff --git a/Tests/PrivacyStatsTests/PrivacyStatsTests.swift b/Tests/PrivacyStatsTests/PrivacyStatsTests.swift new file mode 100644 index 000000000..fa05d8178 --- /dev/null +++ b/Tests/PrivacyStatsTests/PrivacyStatsTests.swift @@ -0,0 +1,317 @@ +// +// PrivacyStatsTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Combine +import Persistence +import TrackerRadarKit +import XCTest +@testable import PrivacyStats + +final class PrivacyStatsTests: XCTestCase { + var databaseProvider: TestPrivacyStatsDatabaseProvider! + var privacyStats: PrivacyStats! + + override func setUp() async throws { + databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description()) + privacyStats = PrivacyStats(databaseProvider: databaseProvider) + } + + override func tearDown() async throws { + databaseProvider.tearDownDatabase() + } + + // MARK: - initializer + + func testThatOutdatedTrackerStatsAreDeletedUponInitialization() async throws { + try databaseProvider.addObjects { context in + let date = Date() + + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "A", count: 7, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 100, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(8), companyName: "A", count: 100, context: context) + ] + } + + // recreate database provider with existing location so that the existing database is persisted in the initializer + databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description(), location: databaseProvider.location) + privacyStats = PrivacyStats(databaseProvider: databaseProvider) + + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 10]) + + let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + do { + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertEqual(Set(allObjects.map(\.count)), [1, 2, 7]) + } catch { + XCTFail("Context fetch should not fail") + } + } + } + + // MARK: - fetchPrivacyStats + + func testThatPrivacyStatsAreFetched() async throws { + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, [:]) + } + + func testThatFetchPrivacyStatsReturnsAllCompanies() async throws { + try databaseProvider.addObjects { context in + [ + DailyBlockedTrackersEntity.make(companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(companyName: "B", count: 5, context: context), + DailyBlockedTrackersEntity.make(companyName: "C", count: 13, context: context), + DailyBlockedTrackersEntity.make(companyName: "D", count: 42, context: context) + ] + } + + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 10, "B": 5, "C": 13, "D": 42]) + } + + func testThatFetchPrivacyStatsReturnsSumOfCompanyEntriesForPast7Days() async throws { + try databaseProvider.addObjects { context in + let date = Date() + + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "A", count: 3, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(3), companyName: "A", count: 4, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "A", count: 5, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(5), companyName: "A", count: 6, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "A", count: 7, context: context) + ] + } + + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 28]) + } + + func testThatFetchPrivacyStatsDiscardsEntriesOlderThan7Days() async throws { + try databaseProvider.addObjects { context in + let date = Date() + + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(500), companyName: "A", count: 10, context: context), + ] + } + + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 3]) + } + + // MARK: - recordBlockedTracker + + func testThatCallingRecordBlockedTrackerCausesDatabaseSaveAfterDelay() async throws { + await privacyStats.recordBlockedTracker("A") + + var stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, [:]) + + try await Task.sleep(nanoseconds: 1_500_000_000) + + stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 1]) + } + + func testThatStatsUpdatePublisherIsCalledAfterDatabaseSave() async throws { + await privacyStats.recordBlockedTracker("A") + + await waitForStatsUpdateEvent() + + var stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 1]) + + await privacyStats.recordBlockedTracker("B") + + await waitForStatsUpdateEvent() + + stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 1, "B": 1]) + } + + func testWhenMultipleTrackersAreReportedInQuickSuccessionThenOnlyOneStatsUpdateEventIsReported() async throws { + await withTaskGroup(of: Void.self) { group in + (0..<5).forEach { _ in + group.addTask { + await self.privacyStats.recordBlockedTracker("A") + } + } + (0..<10).forEach { _ in + group.addTask { + await self.privacyStats.recordBlockedTracker("B") + } + } + (0..<3).forEach { _ in + group.addTask { + await self.privacyStats.recordBlockedTracker("C") + } + } + } + + // We have limited testing possibilities here, so let's just await the first stats update event + // and verify that all trackers are reported by privacy stats. + await waitForStatsUpdateEvent() + + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 5, "B": 10, "C": 3]) + } + + func testThatCallingRecordBlockedTrackerWithNextDayTimestampCausesDeletingOldEntriesFromDatabase() async throws { + try databaseProvider.addObjects { context in + let date = Date() + return [ + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 100, context: context), + ] + } + + // recreate database provider with existing location so that the existing database is persisted in the initializer + databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description(), location: databaseProvider.location) + privacyStats = PrivacyStats(databaseProvider: databaseProvider) + + await privacyStats.recordBlockedTracker("A") + + // No waiting here because the first commit event will be sent immediately from the actor when pack's timestamp changes. + // We aren't testing the debounced commit in this test case. + + var stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 2]) + + let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + do { + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertEqual(Set(allObjects.map(\.count)), [2]) + } catch { + XCTFail("Context fetch should not fail") + } + } + + await waitForStatsUpdateEvent() + stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 3]) + } + + // MARK: - clearPrivacyStats + + func testThatClearPrivacyStatsTriggersUpdatesPublisher() async throws { + try await waitForStatsUpdateEvents(for: 1, count: 1) { + await privacyStats.clearPrivacyStats() + } + } + + func testWhenClearPrivacyStatsIsCalledThenFetchPrivacyStatsIsEmpty() async throws { + try databaseProvider.addObjects { context in + let date = Date() + + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "A", count: 10, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(500), companyName: "A", count: 10, context: context), + ] + } + + var stats = await privacyStats.fetchPrivacyStats() + XCTAssertFalse(stats.isEmpty) + + await privacyStats.clearPrivacyStats() + + stats = await privacyStats.fetchPrivacyStats() + XCTAssertTrue(stats.isEmpty) + + let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + do { + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertTrue(allObjects.isEmpty) + } catch { + XCTFail("fetch failed: \(error)") + } + } + } + + // MARK: - handleAppTermination + + func testThatHandleAppTerminationSavesCurrentPack() async throws { + let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertTrue(allObjects.isEmpty) + } catch { + XCTFail("fetch failed: \(error)") + } + } + await privacyStats.recordBlockedTracker("A") + await privacyStats.handleAppTermination() + + context.performAndWait { + do { + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertEqual(allObjects.count, 1) + } catch { + XCTFail("fetch failed: \(error)") + } + } + + await waitForStatsUpdateEvent() + let stats = await privacyStats.fetchPrivacyStats() + XCTAssertEqual(stats, ["A": 1]) + } + + // MARK: - Helpers + + func waitForStatsUpdateEvent(file: StaticString = #file, line: UInt = #line) async { + let expectation = self.expectation(description: "statsUpdate") + let cancellable = privacyStats.statsUpdatePublisher.sink { expectation.fulfill() } + await fulfillment(of: [expectation], timeout: 2) + cancellable.cancel() + } + + /** + * Sets up an expectation with the fulfillment count specified by `count` parameter, + * then sets up Combine subscription, then calls the provided block and waits + * for time specified by `duration` before cancelling the subscription. + */ + func waitForStatsUpdateEvents(for duration: TimeInterval, count: Int, _ block: () async -> Void) async throws { + let expectation = self.expectation(description: "statsUpdate") + expectation.expectedFulfillmentCount = count + let cancellable = privacyStats.statsUpdatePublisher.sink { expectation.fulfill() } + + await block() + + await fulfillment(of: [expectation], timeout: duration) + cancellable.cancel() + } +} diff --git a/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift b/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift new file mode 100644 index 000000000..ec0e7e606 --- /dev/null +++ b/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift @@ -0,0 +1,360 @@ +// +// PrivacyStatsUtilsTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Persistence +import XCTest +@testable import PrivacyStats + +final class PrivacyStatsUtilsTests: XCTestCase { + var databaseProvider: TestPrivacyStatsDatabaseProvider! + var database: CoreDataDatabase! + + override func setUp() async throws { + databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description()) + databaseProvider.initializeDatabase() + database = databaseProvider.database + } + + override func tearDown() async throws { + databaseProvider.tearDownDatabase() + } + + // MARK: - fetchOrInsertCurrentStats + + func testWhenThereAreNoObjectsForCompaniesThenFetchOrInsertCurrentStatsInsertsNewObjects() { + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + let currentPackTimestamp = Date.currentPrivacyStatsPackTimestamp + let companyNames: Set = ["A", "B", "C", "D"] + + var returnedEntities: [DailyBlockedTrackersEntity] = [] + do { + returnedEntities = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: companyNames, in: context) + } catch { + XCTFail("Should not throw") + } + + let insertedEntities = context.insertedObjects.compactMap { $0 as? DailyBlockedTrackersEntity } + + XCTAssertEqual(returnedEntities.count, 4) + XCTAssertEqual(insertedEntities.count, 4) + XCTAssertEqual(Set(insertedEntities.map(\.companyName)), companyNames) + XCTAssertEqual(Set(insertedEntities.map(\.companyName)), Set(returnedEntities.map(\.companyName))) + + // All inserted entries have the same timestamp + XCTAssertEqual(Set(insertedEntities.map(\.timestamp)), [currentPackTimestamp]) + + // All inserted entries have the count of 0 + XCTAssertEqual(Set(insertedEntities.map(\.count)), [0]) + } + } + + func testWhenThereAreExistingObjectsForCompaniesThenFetchOrInsertCurrentStatsReturnsThem() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 123, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 4567, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + let companyNames: Set = ["A", "B", "C", "D"] + + var returnedEntities: [DailyBlockedTrackersEntity] = [] + do { + returnedEntities = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: companyNames, in: context) + } catch { + XCTFail("Should not throw") + } + + let insertedEntities = context.insertedObjects.compactMap { $0 as? DailyBlockedTrackersEntity } + + XCTAssertEqual(returnedEntities.count, 4) + XCTAssertEqual(insertedEntities.count, 2) + XCTAssertEqual(Set(returnedEntities.map(\.companyName)), companyNames) + XCTAssertEqual(Set(insertedEntities.map(\.companyName)), ["C", "D"]) + + do { + let companyA = try XCTUnwrap(returnedEntities.first { $0.companyName == "A" }) + let companyB = try XCTUnwrap(returnedEntities.first { $0.companyName == "B" }) + + XCTAssertEqual(companyA.count, 123) + XCTAssertEqual(companyB.count, 4567) + } catch { + XCTFail("Should find companies A and B") + } + } + } + + // MARK: - loadCurrentDayStats + + func testWhenThereAreNoObjectsInDatabaseThenLoadCurrentDayStatsIsEmpty() throws { + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context) + XCTAssertTrue(currentDayStats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testWhenThereAreObjectsInDatabaseForPreviousDaysThenLoadCurrentDayStatsIsEmpty() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 123, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "B", count: 4567, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(5), companyName: "C", count: 890, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context) + XCTAssertTrue(currentDayStats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testThatObjectsWithZeroCountAreNotReportedByLoadCurrentDayStats() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 0, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 0, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 0, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context) + XCTAssertTrue(currentDayStats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testThatObjectsWithNonZeroCountAreReportedByLoadCurrentDayStats() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 150, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 400, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 84, context: context), + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "D", count: 5, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context) + XCTAssertEqual(currentDayStats, ["A": 150, "B": 400, "C": 84, "D": 5]) + } catch { + XCTFail("Should not throw") + } + } + } + + // MARK: - load7DayStats + + func testWhenThereAreNoObjectsInDatabaseThenLoad7DayStatsIsEmpty() throws { + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let stats = try PrivacyStatsUtils.load7DayStats(in: context) + XCTAssertTrue(stats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testWhenThereAreObjectsInDatabaseFrom7DaysAgoOrMoreThenLoad7DayStatsIsEmpty() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 123, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "B", count: 4567, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "C", count: 890, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let stats = try PrivacyStatsUtils.load7DayStats(in: context) + XCTAssertTrue(stats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testThatObjectsWithZeroCountAreNotReportedByLoad7DayStats() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 0, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "B", count: 0, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 0, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let stats = try PrivacyStatsUtils.load7DayStats(in: context) + XCTAssertTrue(stats.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } + + func testThatObjectsWithNonZeroCountAreReportedByLoad7DayStats() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 150, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "B", count: 400, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "C", count: 84, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "D", count: 5, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + let stats = try PrivacyStatsUtils.load7DayStats(in: context) + XCTAssertEqual(stats, ["A": 150, "B": 400, "C": 84, "D": 5]) + } catch { + XCTFail("Should not throw") + } + } + } + + // MARK: - deleteOutdatedPacks + + func testWhenDeleteOutdatedPacksIsCalledThenObjectsFrom7DaysAgoOrMoreAreDeleted() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "C", count: 4, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(8), companyName: "C", count: 5, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(100), companyName: "C", count: 6, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + try PrivacyStatsUtils.deleteOutdatedPacks(in: context) + + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertEqual(Set(allObjects.map(\.count)), [1, 2, 3]) + } catch { + XCTFail("Should not throw") + } + } + } + + func testWhenObjectsFrom7DaysAgoOrMoreAreNotPresentThenDeleteOutdatedPacksHasNoEffect() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + try PrivacyStatsUtils.deleteOutdatedPacks(in: context) + + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertEqual(allObjects.count, 3) + } catch { + XCTFail("Should not throw") + } + } + } + + // MARK: - deleteAllStats + + func testThatDeleteAllStatsRemovesAllDatabaseObjects() throws { + let date = Date() + + try databaseProvider.addObjects { context in + return [ + DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(60), companyName: "C", count: 3, context: context), + DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(600), companyName: "C", count: 3, context: context) + ] + } + + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + + context.performAndWait { + do { + try PrivacyStatsUtils.deleteAllStats(in: context) + + let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest()) + XCTAssertTrue(allObjects.isEmpty) + } catch { + XCTFail("Should not throw") + } + } + } +} diff --git a/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift b/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift new file mode 100644 index 000000000..2cb210f0b --- /dev/null +++ b/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift @@ -0,0 +1,65 @@ +// +// TestPrivacyStatsDatabaseProvider.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Persistence +import XCTest +@testable import PrivacyStats + +final class TestPrivacyStatsDatabaseProvider: PrivacyStatsDatabaseProviding { + let databaseName: String + var database: CoreDataDatabase! + var location: URL! + + init(databaseName: String) { + self.databaseName = databaseName + } + + init(databaseName: String, location: URL) { + self.databaseName = databaseName + self.location = location + } + + @discardableResult + func initializeDatabase() -> CoreDataDatabase { + if location == nil { + location = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) + } + let model = CoreDataDatabase.loadModel(from: PrivacyStats.bundle, named: "PrivacyStats")! + database = CoreDataDatabase(name: databaseName, containerLocation: location, model: model) + database.loadStore() + return database + } + + func tearDownDatabase() { + try? database.tearDown(deleteStores: true) + database = nil + try? FileManager.default.removeItem(at: location) + } + + func addObjects(_ objects: (NSManagedObjectContext) -> [DailyBlockedTrackersEntity], file: StaticString = #file, line: UInt = #line) throws { + let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType) + context.performAndWait { + _ = objects(context) + do { + try context.save() + } catch { + XCTFail("save failed: \(error)", file: file, line: line) + } + } + } +} diff --git a/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift new file mode 100644 index 000000000..c6a490eae --- /dev/null +++ b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift @@ -0,0 +1,127 @@ +// +// DefaultRemoteMessagingSurveyURLBuilderTests.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import XCTest +import BrowserServicesKitTestsUtils +import RemoteMessagingTestsUtils +@testable import Subscription +@testable import RemoteMessaging + +class DefaultRemoteMessagingSurveyURLBuilderTests: XCTestCase { + + func testAddingATBParameter() { + let builder = buildRemoteMessagingSurveyURLBuilder(atb: "v456-7") + let baseURL = URL(string: "https://duckduckgo.com")! + let finalURL = builder.add(parameters: [.atb], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?atb=v456-7") + } + + func testAddingATBVariantParameter() { + let builder = buildRemoteMessagingSurveyURLBuilder(variant: "test-variant") + let baseURL = URL(string: "https://duckduckgo.com")! + let finalURL = builder.add(parameters: [.atbVariant], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?var=test-variant") + } + + func testAddingLocaleParameter() { + let builder = buildRemoteMessagingSurveyURLBuilder(locale: Locale(identifier: "en_NZ")) + let baseURL = URL(string: "https://duckduckgo.com")! + let finalURL = builder.add(parameters: [.locale], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?locale=en-NZ") + } + + func testAddingPrivacyProParameters() { + let builder = buildRemoteMessagingSurveyURLBuilder() + let baseURL = URL(string: "https://duckduckgo.com")! + let finalURL = builder.add(parameters: [.privacyProStatus, .privacyProPlatform, .privacyProPlatform], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?ppro_status=auto_renewable&ppro_platform=apple&ppro_platform=apple") + } + + func testAddingVPNUsageParameters() { + let builder = buildRemoteMessagingSurveyURLBuilder(vpnDaysSinceActivation: 10, vpnDaysSinceLastActive: 5) + let baseURL = URL(string: "https://duckduckgo.com")! + let finalURL = builder.add(parameters: [.vpnFirstUsed, .vpnLastUsed], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?vpn_first_used=10&vpn_last_used=5") + } + + func testAddingParametersToURLThatAlreadyHasThem() { + let builder = buildRemoteMessagingSurveyURLBuilder(vpnDaysSinceActivation: 10, vpnDaysSinceLastActive: 5) + let baseURL = URL(string: "https://duckduckgo.com?param=test")! + let finalURL = builder.add(parameters: [.vpnFirstUsed, .vpnLastUsed], to: baseURL) + + XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?param=test&vpn_first_used=10&vpn_last_used=5") + } + + private func buildRemoteMessagingSurveyURLBuilder( + atb: String = "v123-4", + variant: String = "var", + vpnDaysSinceActivation: Int = 2, + vpnDaysSinceLastActive: Int = 1, + locale: Locale = Locale(identifier: "en_US") + ) -> DefaultRemoteMessagingSurveyURLBuilder { + + let mockStatisticsStore = MockStatisticsStore() + mockStatisticsStore.atb = atb + mockStatisticsStore.variant = variant + + let vpnActivationDateStore = MockVPNActivationDateStore( + daysSinceActivation: vpnDaysSinceActivation, + daysSinceLastActive: vpnDaysSinceLastActive + ) + + let subscription = DDGSubscription(productId: "product-id", + name: "product-name", + billingPeriod: .monthly, + startedAt: Date(timeIntervalSince1970: 1000), + expiresOrRenewsAt: Date(timeIntervalSince1970: 2000), + platform: .apple, + status: .autoRenewable) + + return DefaultRemoteMessagingSurveyURLBuilder( + statisticsStore: mockStatisticsStore, + vpnActivationDateStore: vpnActivationDateStore, + subscription: subscription, + localeIdentifier: locale.identifier) + } + +} + +private class MockVPNActivationDateStore: VPNActivationDateProviding { + + var _daysSinceActivation: Int + var _daysSinceLastActive: Int + + init(daysSinceActivation: Int, daysSinceLastActive: Int) { + self._daysSinceActivation = daysSinceActivation + self._daysSinceLastActive = daysSinceLastActive + } + + func daysSinceActivation() -> Int? { + return _daysSinceActivation + } + + func daysSinceLastActive() -> Int? { + return _daysSinceLastActive + } + +} diff --git a/Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift b/Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift similarity index 96% rename from Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift rename to Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift index fa0ebf895..4eec5c739 100644 --- a/Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift +++ b/Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift @@ -1,5 +1,5 @@ // -// SpecialErrorPagesTest.swift +// SpecialErrorPagesTests.swift // // Copyright © 2023 DuckDuckGo. All rights reserved. // @@ -108,7 +108,7 @@ final class SpecialErrorPageUserScriptTests: XCTestCase { @MainActor func test_WhenHandlerForInitialSetUpCalled_AndIsEnabledTrue_ThenRightParameterReturned() async { // GIVEN - let expectedData = SpecialErrorData(kind: .ssl, errorType: "some error type", domain: "someDomain") + let expectedData = SpecialErrorData.ssl(type: .invalid, domain: "someDomain", eTldPlus1: nil) var encodable: Encodable? userScript.isEnabled = true delegate.errorData = expectedData @@ -191,11 +191,11 @@ class CapturingSpecialErrorPageUserScriptDelegate: SpecialErrorPageUserScriptDel var visitSiteCalled = false var advancedInfoPresentedCalled = false - func leaveSite() { + func leaveSiteAction() { leaveSiteCalled = true } - func visitSite() { + func visitSiteAction() { visitSiteCalled = true } diff --git a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift index 5dfa14980..3fb695733 100644 --- a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift +++ b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift @@ -23,7 +23,7 @@ import SubscriptionTestingUtilities final class SubscriptionOptionsTests: XCTestCase { func testEncoding() throws { - let subscriptionOptions = SubscriptionOptions(platform: "macos", + let subscriptionOptions = SubscriptionOptions(platform: .macos, options: [ SubscriptionOption(id: "1", cost: SubscriptionOptionCost(displayPrice: "9 USD", recurrence: "monthly")), @@ -31,9 +31,9 @@ final class SubscriptionOptionsTests: XCTestCase { cost: SubscriptionOptionCost(displayPrice: "99 USD", recurrence: "yearly")) ], features: [ - SubscriptionFeature(name: "vpn"), - SubscriptionFeature(name: "personal-information-removal"), - SubscriptionFeature(name: "identity-theft-restoration") + SubscriptionFeature(name: .networkProtection), + SubscriptionFeature(name: .dataBrokerProtection), + SubscriptionFeature(name: .identityTheftRestoration) ]) let jsonEncoder = JSONEncoder() @@ -45,13 +45,13 @@ final class SubscriptionOptionsTests: XCTestCase { { "features" : [ { - "name" : "vpn" + "name" : "Network Protection" }, { - "name" : "personal-information-removal" + "name" : "Data Broker Protection" }, { - "name" : "identity-theft-restoration" + "name" : "Identity Theft Restoration" } ], "options" : [ @@ -87,12 +87,12 @@ final class SubscriptionOptionsTests: XCTestCase { } func testSubscriptionFeatureEncoding() throws { - let subscriptionFeature = SubscriptionFeature(name: "identity-theft-restoration") + let subscriptionFeature = SubscriptionFeature(name: .identityTheftRestoration) let data = try? JSONEncoder().encode(subscriptionFeature) let subscriptionFeatureString = String(data: data!, encoding: .utf8)! - XCTAssertEqual(subscriptionFeatureString, "{\"name\":\"identity-theft-restoration\"}") + XCTAssertEqual(subscriptionFeatureString, "{\"name\":\"Identity Theft Restoration\"}") } func testEmptySubscriptionOptions() throws { @@ -105,8 +105,8 @@ final class SubscriptionOptionsTests: XCTestCase { platform = .macos #endif - XCTAssertEqual(empty.platform, platform.rawValue) + XCTAssertEqual(empty.platform, platform) XCTAssertTrue(empty.options.isEmpty) - XCTAssertEqual(empty.features.count, SubscriptionFeatureName.allCases.count) + XCTAssertEqual(empty.features.count, 3) } } diff --git a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift index 1752e3d15..e397805db 100644 --- a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift +++ b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift @@ -67,12 +67,14 @@ final class StripePurchaseFlowTests: XCTestCase { // Then switch result { case .success(let success): - XCTAssertEqual(success.platform, SubscriptionPlatformName.stripe.rawValue) + XCTAssertEqual(success.platform, SubscriptionPlatformName.stripe) XCTAssertEqual(success.options.count, SubscriptionMockFactory.productsItems.count) - XCTAssertEqual(success.features.count, SubscriptionFeatureName.allCases.count) + XCTAssertEqual(success.features.count, 3) + let allFeatures = [Entitlement.ProductName.networkProtection, Entitlement.ProductName.dataBrokerProtection, Entitlement.ProductName.identityTheftRestoration] let allNames = success.features.compactMap({ feature in feature.name}) - for name in SubscriptionFeatureName.allCases { - XCTAssertTrue(allNames.contains(name.rawValue)) + + for feature in allFeatures { + XCTAssertTrue(allNames.contains(feature)) } case .failure(let error): XCTFail("Unexpected failure: \(error)") diff --git a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift index 9054194da..26ce6d89c 100644 --- a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift +++ b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift @@ -34,7 +34,9 @@ final class SubscriptionManagerTests: XCTestCase { var accountManager: AccountManagerMock! var subscriptionService: SubscriptionEndpointServiceMock! var authService: AuthEndpointServiceMock! + var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock! var subscriptionEnvironment: SubscriptionEnvironment! + var subscriptionFeatureFlagger: FeatureFlaggerMapping! var subscriptionManager: SubscriptionManager! @@ -43,14 +45,18 @@ final class SubscriptionManagerTests: XCTestCase { accountManager = AccountManagerMock() subscriptionService = SubscriptionEndpointServiceMock() authService = AuthEndpointServiceMock() + subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, purchasePlatform: .appStore) + subscriptionFeatureFlagger = FeatureFlaggerMapping(mapping: { $0.defaultState }) subscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager, accountManager: accountManager, subscriptionEndpointService: subscriptionService, authEndpointService: authService, - subscriptionEnvironment: subscriptionEnvironment) + subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, + subscriptionEnvironment: subscriptionEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger) } @@ -202,7 +208,9 @@ final class SubscriptionManagerTests: XCTestCase { accountManager: accountManager, subscriptionEndpointService: subscriptionService, authEndpointService: authService, - subscriptionEnvironment: productionEnvironment) + subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, + subscriptionEnvironment: productionEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger) // When let productionPurchaseURL = productionSubscriptionManager.url(for: .purchase) @@ -219,7 +227,9 @@ final class SubscriptionManagerTests: XCTestCase { accountManager: accountManager, subscriptionEndpointService: subscriptionService, authEndpointService: authService, - subscriptionEnvironment: stagingEnvironment) + subscriptionFeatureMappingCache: subscriptionFeatureMappingCache, + subscriptionEnvironment: stagingEnvironment, + subscriptionFeatureFlagger: subscriptionFeatureFlagger) // When let stagingPurchaseURL = stagingSubscriptionManager.url(for: .purchase) diff --git a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift index 07fb12bd4..2a6a9d3d8 100644 --- a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift +++ b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift @@ -33,6 +33,7 @@ final class SubscriptionCookieManagerTests: XCTestCase { var authService: AuthEndpointServiceMock! var storePurchaseManager: StorePurchaseManagerMock! var subscriptionEnvironment: SubscriptionEnvironment! + var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock! var subscriptionManager: SubscriptionManagerMock! var cookieStore: HTTPCookieStore! @@ -45,13 +46,15 @@ final class SubscriptionCookieManagerTests: XCTestCase { storePurchaseManager = StorePurchaseManagerMock() subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production, purchasePlatform: .appStore) + subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock() subscriptionManager = SubscriptionManagerMock(accountManager: accountManager, subscriptionEndpointService: subscriptionService, authEndpointService: authService, storePurchaseManager: storePurchaseManager, currentEnvironment: subscriptionEnvironment, - canPurchase: true) + canPurchase: true, + subscriptionFeatureMappingCache: subscriptionFeatureMappingCache) cookieStore = MockHTTPCookieStore() subscriptionCookieManager = SubscriptionCookieManager(subscriptionManager: subscriptionManager,