diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme
index 56a2ef845..7aeef5267 100644
--- a/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme
+++ b/.swiftpm/xcode/xcshareddata/xcschemes/BrowserServicesKit-Package.xcscheme
@@ -491,9 +491,9 @@
buildForAnalyzing = "YES">
@@ -539,6 +539,20 @@
ReferencedContainer = "container:">
+
+
+
+
@@ -818,6 +832,16 @@
ReferencedContainer = "container:">
+
+
+
+
Bool {
Logger.contentBlocking.debug("Fetch last compiled rules: \(lastCompiledRules.count, privacy: .public)")
let initialCompilationTask = LastCompiledRulesLookupTask(sourceRules: rulesSource.contentBlockerRulesLists,
@@ -274,8 +281,10 @@ public class ContentBlockerRulesManager: CompiledRuleListsSource {
// We want to confine Compilation work to WorkQueue, so we wait to come back from async Task
mutex.wait()
- if let rules = initialCompilationTask.getFetchedRules() {
- applyRules(rules)
+ let rulesFound = initialCompilationTask.getFetchedRules()
+
+ if let rulesFound {
+ applyRules(rulesFound)
} else {
lock.lock()
state = .idle
@@ -284,6 +293,8 @@ public class ContentBlockerRulesManager: CompiledRuleListsSource {
// No matter if rules were found or not, we need to schedule recompilation, after all
scheduleCompilation()
+
+ return rulesFound != nil
}
private func prepareSourceManagers() {
diff --git a/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift b/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift
index ba10bf76f..0c83b7dd8 100644
--- a/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift
+++ b/Sources/BrowserServicesKit/ContentBlocking/UserScripts/SurrogatesUserScript.swift
@@ -16,11 +16,11 @@
// limitations under the License.
//
-import WebKit
+import Common
+import ContentBlocking
import TrackerRadarKit
import UserScript
-import ContentBlocking
-import Common
+@preconcurrency import WebKit
public protocol SurrogatesUserScriptDelegate: NSObjectProtocol {
diff --git a/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift b/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift
index 215dbcc6f..f29a6e520 100644
--- a/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift
+++ b/Sources/BrowserServicesKit/ContentScopeScript/SpecialPagesUserScript.swift
@@ -43,8 +43,7 @@ public final class SpecialPagesUserScript: NSObject, UserScript, UserScriptMessa
@available(macOS 11.0, iOS 14.0, *)
extension SpecialPagesUserScript: WKScriptMessageHandlerWithReply {
@MainActor
- public func userContentController(_ userContentController: WKUserContentController,
- didReceive message: WKScriptMessage) async -> (Any?, String?) {
+ public func userContentController(_ userContentController: WKUserContentController, didReceive message: WKScriptMessage) async -> (Any?, String?) {
let action = broker.messageHandlerFor(message)
do {
let json = try await broker.execute(action: action, original: message)
diff --git a/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift b/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift
index 9d534abe2..8c2ee2169 100644
--- a/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift
+++ b/Sources/BrowserServicesKit/ContentScopeScript/UserContentController.swift
@@ -18,10 +18,10 @@
import Combine
import Common
-import UserScript
-import WebKit
-import QuartzCore
import os.log
+import QuartzCore
+import UserScript
+@preconcurrency import WebKit
public protocol UserContentControllerDelegate: AnyObject {
@MainActor
diff --git a/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift
new file mode 100644
index 000000000..60a03f8c4
--- /dev/null
+++ b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentCohortsManager.swift
@@ -0,0 +1,144 @@
+//
+// ExperimentCohortsManager.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+public typealias CohortID = String
+public typealias SubfeatureID = String
+public typealias ParentFeatureID = String
+public typealias Experiments = [String: ExperimentData]
+
+public struct ExperimentSubfeature {
+ let parentID: ParentFeatureID
+ let subfeatureID: SubfeatureID
+ let cohorts: [PrivacyConfigurationData.Cohort]
+}
+
+public struct ExperimentData: Codable, Equatable {
+ public let parentID: ParentFeatureID
+ public let cohortID: CohortID
+ public let enrollmentDate: Date
+}
+
+public protocol ExperimentCohortsManaging {
+ /// Retrieves all the experiments a user is enrolled in
+ var experiments: Experiments? { get }
+
+ /// Resolves the cohort for a given experiment subfeature.
+ ///
+ /// This method determines whether the user is currently assigned to a valid cohort
+ /// for the specified experiment. If the assigned cohort is valid (i.e., it matches
+ /// one of the experiment's defined cohorts), the method returns the assigned cohort.
+ /// Otherwise, the invalid cohort is removed, and a new cohort is assigned if
+ /// `allowCohortReassignment` is `true`.
+ ///
+ /// - Parameters:
+ /// - experiment: The `ExperimentSubfeature` representing the experiment and its associated cohorts.
+ /// - allowCohortReassignment: A Boolean value indicating whether cohort assignment is allowed
+ /// if the user is not already assigned to a valid cohort.
+ ///
+ /// - Returns: The valid `CohortID` assigned to the user for the experiment, or `nil`
+ /// if no valid cohort exists and `allowCohortReassignment` is `false`.
+ ///
+ /// - Behavior:
+ /// 1. Retrieves the currently assigned cohort for the experiment using the `subfeatureID`.
+ /// 2. Validates if the assigned cohort exists within the experiment's cohort list:
+ /// - If valid, the assigned cohort is returned.
+ /// - If invalid, the cohort is removed from storage.
+ /// 3. If cohort assignment is enabled (`allowCohortReassignment` is `true`), a new cohort
+ /// is assigned based on the experiment's cohort weights and saved in storage.
+ /// - Cohort assignment is probabilistic, determined by the cohort weights.
+ ///
+ func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID?
+}
+
+public class ExperimentCohortsManager: ExperimentCohortsManaging {
+
+ private var store: ExperimentsDataStoring
+ private let randomizer: (Range) -> Double
+ private let queue = DispatchQueue(label: "com.ExperimentCohortsManager.queue")
+
+ public var experiments: Experiments? {
+ get {
+ queue.sync {
+ store.experiments
+ }
+ }
+ }
+
+ public init(store: ExperimentsDataStoring = ExperimentsDataStore(), randomizer: @escaping (Range) -> Double = Double.random(in:)) {
+ self.store = store
+ self.randomizer = randomizer
+ }
+
+ public func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? {
+ queue.sync {
+ let assignedCohort = cohort(for: experiment.subfeatureID)
+ if experiment.cohorts.contains(where: { $0.name == assignedCohort }) {
+ return assignedCohort
+ }
+ removeCohort(from: experiment.subfeatureID)
+ return allowCohortReassignment ? assignCohort(to: experiment) : nil
+ }
+ }
+}
+
+// MARK: Helper functions
+extension ExperimentCohortsManager {
+
+ private func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID? {
+ let cohorts = subfeature.cohorts
+ let totalWeight = cohorts.map(\.weight).reduce(0, +)
+ guard totalWeight > 0 else { return nil }
+
+ let randomValue = randomizer(0.. CohortID? {
+ guard let experiments = store.experiments else { return nil }
+ return experiments[subfeatureID]?.cohortID
+ }
+
+ private func enrollmentDate(for subfeatureID: SubfeatureID) -> Date? {
+ guard let experiments = store.experiments else { return nil }
+ return experiments[subfeatureID]?.enrollmentDate
+ }
+
+ private func removeCohort(from subfeatureID: SubfeatureID) {
+ guard var experiments = store.experiments else { return }
+ experiments.removeValue(forKey: subfeatureID)
+ store.experiments = experiments
+ }
+
+ private func saveCohort(_ cohort: CohortID, in experimentID: SubfeatureID, parentID: ParentFeatureID) {
+ var experiments = store.experiments ?? Experiments()
+ let experimentData = ExperimentData(parentID: parentID, cohortID: cohort, enrollmentDate: Date())
+ experiments[experimentID] = experimentData
+ store.experiments = experiments
+ }
+}
diff --git a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift
similarity index 86%
rename from Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift
rename to Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift
index bdf82819a..c99184df6 100644
--- a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentsDataStore.swift
+++ b/Sources/BrowserServicesKit/FeatureFlagger/ExperimentsDataStore.swift
@@ -18,16 +18,16 @@
import Foundation
-protocol ExperimentsDataStoring {
+public protocol ExperimentsDataStoring {
var experiments: Experiments? { get set }
}
-protocol LocalDataStoring {
+public protocol LocalDataStoring {
func data(forKey defaultName: String) -> Data?
func set(_ value: Any?, forKey defaultName: String)
}
-struct ExperimentsDataStore: ExperimentsDataStoring {
+public struct ExperimentsDataStore: ExperimentsDataStoring {
private enum Constants {
static let experimentsDataKey = "ExperimentsData"
@@ -36,13 +36,13 @@ struct ExperimentsDataStore: ExperimentsDataStoring {
private let decoder = JSONDecoder()
private let encoder = JSONEncoder()
- init(localDataStoring: LocalDataStoring = UserDefaults.standard) {
+ public init(localDataStoring: LocalDataStoring = UserDefaults.standard) {
self.localDataStoring = localDataStoring
encoder.dateEncodingStrategy = .secondsSince1970
decoder.dateDecodingStrategy = .secondsSince1970
}
- var experiments: Experiments? {
+ public var experiments: Experiments? {
get {
guard let savedData = localDataStoring.data(forKey: Constants.experimentsDataKey) else { return nil }
return try? decoder.decode(Experiments.self, from: savedData)
diff --git a/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift b/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift
index d6f5fdee4..3ecb53d1e 100644
--- a/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift
+++ b/Sources/BrowserServicesKit/FeatureFlagger/FeatureFlagger.swift
@@ -18,6 +18,8 @@
import Foundation
+public protocol FlagCohort: RawRepresentable, CaseIterable where RawValue == CohortID {}
+
/// This protocol defines a common interface for feature flags managed by FeatureFlagger.
///
/// It should be implemented by the feature flag type in client apps.
@@ -53,7 +55,43 @@ public protocol FeatureFlagDescribing: CaseIterable {
/// case .sync:
/// return .disabled
/// case .cookieConsent:
- /// return .internalOnly
+ /// return .internalOnly()
+ /// case .credentialsAutofill:
+ /// return .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))
+ /// case .duckPlayer:
+ /// return .remoteReleasable(.feature(.duckPlayer))
+ /// }
+ /// }
+ /// ```
+ var source: FeatureFlagSource { get }
+}
+
+/// This protocol defines a common interface for experiment feature flags managed by FeatureFlagger.
+///
+/// It should be implemented by the feature flag type in client apps.
+///
+public protocol FeatureFlagExperimentDescribing {
+
+ /// Returns a string representation of the flag
+ var rawValue: String { get }
+
+ /// Defines the source of the experiment feature flag, which corresponds to
+ /// where the final flag value should come from.
+ ///
+ /// Example client implementation:
+ ///
+ /// ```
+ /// public enum FeatureFlag: FeatureFlagDescribing {
+ /// case sync
+ /// case autofill
+ /// case cookieConsent
+ /// case duckPlayer
+ ///
+ /// var source: FeatureFlagSource {
+ /// case .sync:
+ /// return .disabled
+ /// case .cookieConsent:
+ /// return .internalOnly(cohort)
/// case .credentialsAutofill:
/// return .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill))
/// case .duckPlayer:
@@ -62,6 +100,29 @@ public protocol FeatureFlagDescribing: CaseIterable {
/// }
/// ```
var source: FeatureFlagSource { get }
+
+ /// Represents the possible groups or variants within an experiment.
+ ///
+ /// The `Cohort` type is used to define user groups or test variations for feature
+ /// experimentation. Each cohort typically corresponds to a specific behavior or configuration
+ /// applied to a subset of users. For example, in an A/B test, you might define cohorts such as
+ /// `control` and `treatment`.
+ ///
+ /// Each cohort must conform to the `CohortEnum` protocol, which ensures that the cohort type
+ /// is an `enum` with `String` raw values and provides access to all possible cases
+ /// through `CaseIterable`.
+ ///
+ /// Example:
+ /// ```
+ /// public enum AutofillCohorts: String, CohortEnum {
+ /// case control
+ /// case treatment
+ /// }
+ /// ```
+ ///
+ /// The `Cohort` type allows dynamic resolution of cohorts by their raw `String` value,
+ /// making it easy to map user configurations to specific cohort groups.
+ associatedtype CohortType: FlagCohort
}
public enum FeatureFlagSource {
@@ -69,7 +130,7 @@ public enum FeatureFlagSource {
case disabled
/// Enabled for internal users only. Cannot be toggled remotely
- case internalOnly
+ case internalOnly((any FlagCohort)? = nil)
/// Toggled remotely using PrivacyConfiguration but only for internal users. Otherwise, disabled.
case remoteDevelopment(PrivacyConfigFeatureLevel)
@@ -107,6 +168,44 @@ public protocol FeatureFlagger: AnyObject {
/// when the non-overridden feature flag value is required.
///
func isFeatureOn(for featureFlag: Flag, allowOverride: Bool) -> Bool
+
+ /// Retrieves the cohort for a feature flag if the feature is enabled.
+ ///
+ /// This method determines the source of the feature flag and evaluates its eligibility based on
+ /// the user's internal status and the privacy configuration. It supports different sources, such as
+ /// disabled features, internal-only features, and remotely toggled features.
+ ///
+ /// - Parameter featureFlag: A feature flag conforming to `FeatureFlagDescribing`.
+ ///
+ /// - Returns: The `CohortID` associated with the feature flag, or `nil` if the feature is disabled or
+ /// does not meet the eligibility criteria.
+ ///
+ /// - Behavior:
+ /// - For `.disabled`: Returns `nil`.
+ /// - For `.internalOnly`: Returns the cohort if the user is an internal user.
+ /// - For `.remoteDevelopment` and `.remoteReleasable`:
+ /// - If the feature is a subfeature, resolves its cohort using `getCohortIfEnabled(_ subfeature:)`.
+ /// - Returns `nil` if the user is not eligible.
+ ///
+ func getCohortIfEnabled(for featureFlag: Flag) -> (any FlagCohort)?
+
+ /// Retrieves all active experiments currently assigned to the user.
+ ///
+ /// This method iterates over the experiments stored in the `ExperimentManager` and checks their state
+ /// against the current `PrivacyConfiguration`. If an experiment's state is enabled or disabled due to
+ /// a target mismatch, and its assigned cohort matches the resolved cohort, it is considered active.
+ ///
+ /// - Returns: A dictionary of active experiments where the key is the experiment's subfeature ID,
+ /// and the value is the associated `ExperimentData`.
+ ///
+ /// - Behavior:
+ /// 1. Fetches all enrolled experiments from the `ExperimentManager`.
+ /// 2. For each experiment:
+ /// - Retrieves its state from the `PrivacyConfiguration`.
+ /// - Validates its assigned cohort using `resolveCohort` in the `ExperimentManager`.
+ /// 3. If the experiment passes validation, it is added to the result dictionary.
+ ///
+ func getAllActiveExperiments() -> Experiments
}
public extension FeatureFlagger {
@@ -126,14 +225,17 @@ public class DefaultFeatureFlagger: FeatureFlagger {
public let internalUserDecider: InternalUserDecider
public let privacyConfigManager: PrivacyConfigurationManaging
+ private let experimentManager: ExperimentCohortsManaging?
public let localOverrides: FeatureFlagLocalOverriding?
public init(
internalUserDecider: InternalUserDecider,
- privacyConfigManager: PrivacyConfigurationManaging
+ privacyConfigManager: PrivacyConfigurationManaging,
+ experimentManager: ExperimentCohortsManaging?
) {
self.internalUserDecider = internalUserDecider
self.privacyConfigManager = privacyConfigManager
+ self.experimentManager = experimentManager
self.localOverrides = nil
}
@@ -141,11 +243,13 @@ public class DefaultFeatureFlagger: FeatureFlagger {
internalUserDecider: InternalUserDecider,
privacyConfigManager: PrivacyConfigurationManaging,
localOverrides: FeatureFlagLocalOverriding,
+ experimentManager: ExperimentCohortsManaging?,
for: Flag.Type
) {
self.internalUserDecider = internalUserDecider
self.privacyConfigManager = privacyConfigManager
self.localOverrides = localOverrides
+ self.experimentManager = experimentManager
localOverrides.featureFlagger = self
// Clear all overrides if not an internal user
@@ -173,6 +277,58 @@ public class DefaultFeatureFlagger: FeatureFlagger {
}
}
+ public func getAllActiveExperiments() -> Experiments {
+ guard let enrolledExperiments = experimentManager?.experiments else { return [:] }
+ var activeExperiments = [String: ExperimentData]()
+ let config = privacyConfigManager.privacyConfig
+
+ for (subfeatureID, experimentData) in enrolledExperiments {
+ let state = config.stateFor(subfeatureID: subfeatureID, parentFeatureID: experimentData.parentID)
+ guard state == .enabled || state == .disabled(.targetDoesNotMatch) else { continue }
+ let cohorts = config.cohorts(subfeatureID: subfeatureID, parentFeatureID: experimentData.parentID) ?? []
+ let experimentSubfeature = ExperimentSubfeature(parentID: experimentData.parentID, subfeatureID: subfeatureID, cohorts: cohorts)
+
+ if experimentManager?.resolveCohort(for: experimentSubfeature, allowCohortReassignment: false) == experimentData.cohortID {
+ activeExperiments[subfeatureID] = experimentData
+ }
+ }
+ return activeExperiments
+ }
+
+ public func getCohortIfEnabled(for featureFlag: Flag) -> (any FlagCohort)? {
+ switch featureFlag.source {
+ case .disabled:
+ return nil
+ case .internalOnly(let cohort):
+ return cohort
+ case .remoteReleasable(let featureType),
+ .remoteDevelopment(let featureType) where internalUserDecider.isInternalUser:
+ if case .subfeature(let subfeature) = featureType {
+ if let resolvedCohortID = getCohortIfEnabled(subfeature) {
+ return Flag.CohortType.allCases.first { return $0.rawValue == resolvedCohortID }
+ }
+ }
+ return nil
+ default:
+ return nil
+ }
+ }
+
+ private func getCohortIfEnabled(_ subfeature: any PrivacySubfeature) -> CohortID? {
+ let config = privacyConfigManager.privacyConfig
+ let featureState = config.stateFor(subfeature)
+ let cohorts = config.cohorts(for: subfeature)
+ let experiment = ExperimentSubfeature(parentID: subfeature.parent.rawValue, subfeatureID: subfeature.rawValue, cohorts: cohorts ?? [])
+ switch featureState {
+ case .enabled:
+ return experimentManager?.resolveCohort(for: experiment, allowCohortReassignment: true)
+ case .disabled(.targetDoesNotMatch):
+ return experimentManager?.resolveCohort(for: experiment, allowCohortReassignment: false)
+ default:
+ return nil
+ }
+ }
+
private func isEnabled(_ featureType: PrivacyConfigFeatureLevel) -> Bool {
switch featureType {
case .feature(let feature):
diff --git a/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift b/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift
index 89e64093d..2390aff4d 100644
--- a/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift
+++ b/Sources/BrowserServicesKit/PrivacyConfig/AppPrivacyConfiguration.swift
@@ -33,19 +33,23 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
private let locallyUnprotected: DomainsProtectionStore
private let internalUserDecider: InternalUserDecider
private let userDefaults: UserDefaults
+ private let locale: Locale
private let installDate: Date?
+ static let experimentManagerQueue = DispatchQueue(label: "com.experimentManager.queue")
public init(data: PrivacyConfigurationData,
identifier: String,
localProtection: DomainsProtectionStore,
internalUserDecider: InternalUserDecider,
userDefaults: UserDefaults = UserDefaults(),
+ locale: Locale = Locale.current,
installDate: Date? = nil) {
self.data = data
self.identifier = identifier
self.locallyUnprotected = localProtection
self.internalUserDecider = internalUserDecider
self.userDefaults = userDefaults
+ self.locale = locale
self.installDate = installDate
}
@@ -137,13 +141,14 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
}
}
- private func isRolloutEnabled(subfeature: any PrivacySubfeature,
+ private func isRolloutEnabled(subfeatureID: SubfeatureID,
+ parentID: ParentFeatureID,
rolloutSteps: [PrivacyConfigurationData.PrivacyFeature.Feature.RolloutStep],
randomizer: (Range) -> Double) -> Bool {
// Empty rollouts should be default enabled
guard !rolloutSteps.isEmpty else { return true }
- let defsPrefix = "config.\(subfeature.parent.rawValue).\(subfeature.rawValue)"
+ let defsPrefix = "config.\(parentID).\(subfeatureID)"
if userDefaults.bool(forKey: "\(defsPrefix).\(Constants.enabledKey)") {
return true
}
@@ -182,7 +187,6 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
public func isSubfeatureEnabled(_ subfeature: any PrivacySubfeature,
versionProvider: AppVersionProvider,
randomizer: (Range) -> Double) -> Bool {
-
switch stateFor(subfeature, versionProvider: versionProvider, randomizer: randomizer) {
case .enabled:
return true
@@ -191,17 +195,30 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
}
}
- public func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ public func stateFor(_ subfeature: any PrivacySubfeature,
+ versionProvider: AppVersionProvider,
+ randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ guard let subfeatureData = subfeatures(for: subfeature.parent)[subfeature.rawValue] else {
+ return .disabled(.featureMissing)
+ }
- let parentState = stateFor(featureKey: subfeature.parent, versionProvider: versionProvider)
- guard case .enabled = parentState else { return parentState }
+ return stateFor(subfeatureID: subfeature.rawValue, subfeatureData: subfeatureData, parentFeature: subfeature.parent, versionProvider: versionProvider, randomizer: randomizer)
+ }
- let subfeatures = subfeatures(for: subfeature.parent)
- let subfeatureData = subfeatures[subfeature.rawValue]
+ private func stateFor(subfeatureID: SubfeatureID,
+ subfeatureData: PrivacyConfigurationData.PrivacyFeature.Feature,
+ parentFeature: PrivacyFeature,
+ versionProvider: AppVersionProvider,
+ randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ // Step 1: Check parent feature state
+ let parentState = stateFor(featureKey: parentFeature, versionProvider: versionProvider)
+ guard case .enabled = parentState else { return parentState }
- let satisfiesMinVersion = satisfiesMinVersion(subfeatureData?.minSupportedVersion, versionProvider: versionProvider)
+ // Step 2: Check version
+ let satisfiesMinVersion = satisfiesMinVersion(subfeatureData.minSupportedVersion, versionProvider: versionProvider)
- switch subfeatureData?.state {
+ // Step 3: Check sub-feature state
+ switch subfeatureData.state {
case PrivacyConfigurationData.State.enabled:
guard satisfiesMinVersion else { return .disabled(.appVersionNotSupported) }
case PrivacyConfigurationData.State.internal:
@@ -210,15 +227,31 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
default: return .disabled(.disabledInConfig)
}
- // Handle Rollouts
- if let rollout = subfeatureData?.rollout,
- !isRolloutEnabled(subfeature: subfeature, rolloutSteps: rollout.steps, randomizer: randomizer) {
+ // Step 4: Handle Rollouts
+ if let rollout = subfeatureData.rollout,
+ !isRolloutEnabled(subfeatureID: subfeatureID, parentID: parentFeature.rawValue, rolloutSteps: rollout.steps, randomizer: randomizer) {
return .disabled(.stillInRollout)
}
+ // Step 5: Check Targets
+ return checkTargets(subfeatureData)
+ }
+
+ private func checkTargets(_ subfeatureData: PrivacyConfigurationData.PrivacyFeature.Feature?) -> PrivacyConfigurationFeatureState {
+ // Check Targets
+ if let targets = subfeatureData?.targets, !matchTargets(targets: targets){
+ return .disabled(.targetDoesNotMatch)
+ }
return .enabled
}
+ private func matchTargets(targets: [PrivacyConfigurationData.PrivacyFeature.Feature.Target]) -> Bool {
+ targets.contains { target in
+ (target.localeCountry == nil || target.localeCountry == locale.regionCode) &&
+ (target.localeLanguage == nil || target.localeLanguage == locale.languageCode)
+ }
+ }
+
private func subfeatures(for feature: PrivacyFeature) -> PrivacyConfigurationData.PrivacyFeature.Features {
return data.features[feature.rawValue]?.features ?? [:]
}
@@ -247,7 +280,7 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
guard let domain = domain else { return true }
return !isTempUnprotected(domain: domain) && !isUserUnprotected(domain: domain) &&
- !isInExceptionList(domain: domain, forFeature: .contentBlocking)
+ !isInExceptionList(domain: domain, forFeature: .contentBlocking)
}
public func isUserUnprotected(domain: String?) -> Bool {
@@ -297,7 +330,25 @@ public struct AppPrivacyConfiguration: PrivacyConfiguration {
public func userDisabledProtection(forDomain domain: String) {
locallyUnprotected.disableProtection(forDomain: domain.punycodeEncodedHostname.lowercased())
}
+}
+extension AppPrivacyConfiguration {
+
+ public func stateFor(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID, versionProvider: AppVersionProvider,
+ randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ guard let parentFeature = PrivacyFeature(rawValue: parentFeatureID) else { return .disabled(.featureMissing) }
+ guard let subfeatureData = subfeatures(for: parentFeature)[subfeatureID] else { return .disabled(.featureMissing) }
+ return stateFor(subfeatureID: subfeatureID, subfeatureData: subfeatureData, parentFeature: parentFeature, versionProvider: versionProvider, randomizer: randomizer)
+ }
+
+ public func cohorts(for subfeature: any PrivacySubfeature) -> [PrivacyConfigurationData.Cohort]? {
+ subfeatures(for: subfeature.parent)[subfeature.rawValue]?.cohorts
+ }
+
+ public func cohorts(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID) -> [PrivacyConfigurationData.Cohort]? {
+ guard let parentFeature = PrivacyFeature(rawValue: parentFeatureID) else { return nil }
+ return subfeatures(for: parentFeature)[subfeatureID]?.cohorts
+ }
}
extension Array where Element == String {
diff --git a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift b/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift
deleted file mode 100644
index abd01290b..000000000
--- a/Sources/BrowserServicesKit/PrivacyConfig/ExperimentCohortsManager.swift
+++ /dev/null
@@ -1,107 +0,0 @@
-//
-// ExperimentCohortsManager.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-
-struct ExperimentSubfeature {
- let subfeatureID: SubfeatureID
- let cohorts: [PrivacyConfigurationData.Cohort]
-}
-
-typealias CohortID = String
-typealias SubfeatureID = String
-
-struct ExperimentData: Codable, Equatable {
- let cohort: String
- let enrollmentDate: Date
-}
-
-typealias Experiments = [String: ExperimentData]
-
-protocol ExperimentCohortsManaging {
- /// Retrieves the cohort ID associated with the specified subfeature.
- /// - Parameter subfeatureID: The name of the experiment subfeature for which the cohort ID is needed.
- /// - Returns: The cohort ID as a `String` if one exists; otherwise, returns `nil`.
- func cohort(for subfeatureID: SubfeatureID) -> CohortID?
-
- /// Retrieves the enrollment date for the specified subfeature.
- /// - Parameter subfeatureID: The name of the experiment subfeature for which the enrollment date is needed.
- /// - Returns: The `Date` of enrollment if one exists; otherwise, returns `nil`.
- func enrollmentDate(for subfeatureID: SubfeatureID) -> Date?
-
- /// Assigns a cohort to the given subfeature based on defined weights and saves it to UserDefaults.
- /// - Parameter subfeature: The ExperimentSubfeature to which a cohort needs to be assigned to.
- /// - Returns: The name of the assigned cohort, or `nil` if no cohort could be assigned.
- func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID?
-
- /// Removes the assigned cohort data for the specified subfeature.
- /// - Parameter subfeatureID: The name of the experiment subfeature for which the cohort data should be removed.
- func removeCohort(from subfeatureID: SubfeatureID)
-}
-
-final class ExperimentCohortsManager: ExperimentCohortsManaging {
-
- private var store: ExperimentsDataStoring
- private let randomizer: (Range) -> Double
-
- init(store: ExperimentsDataStoring = ExperimentsDataStore(), randomizer: @escaping (Range) -> Double) {
- self.store = store
- self.randomizer = randomizer
- }
-
- func cohort(for subfeatureID: SubfeatureID) -> CohortID? {
- guard let experiments = store.experiments else { return nil }
- return experiments[subfeatureID]?.cohort
- }
-
- func enrollmentDate(for subfeatureID: SubfeatureID) -> Date? {
- guard let experiments = store.experiments else { return nil }
- return experiments[subfeatureID]?.enrollmentDate
- }
-
- func assignCohort(to subfeature: ExperimentSubfeature) -> CohortID? {
- let cohorts = subfeature.cohorts
- let totalWeight = cohorts.map(\.weight).reduce(0, +)
- guard totalWeight > 0 else { return nil }
-
- let randomValue = randomizer(0..) -> Double) -> PrivacyConfigurationFeatureState
+
+ func cohorts(for subfeature: any PrivacySubfeature) -> [PrivacyConfigurationData.Cohort]?
+
+ func cohorts(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID) -> [PrivacyConfigurationData.Cohort]?
}
public extension PrivacyConfiguration {
@@ -120,4 +129,9 @@ public extension PrivacyConfiguration {
func stateFor(_ subfeature: any PrivacySubfeature, randomizer: (Range) -> Double = Double.random(in:)) -> PrivacyConfigurationFeatureState {
return stateFor(subfeature, versionProvider: AppVersionProvider(), randomizer: randomizer)
}
+
+ func stateFor(subfeatureID: SubfeatureID, parentFeatureID: ParentFeatureID, randomizer: (Range) -> Double = Double.random(in:)) -> PrivacyConfigurationFeatureState {
+ return stateFor(subfeatureID: subfeatureID, parentFeatureID: parentFeatureID, versionProvider: AppVersionProvider(), randomizer: randomizer)
+ }
+
}
diff --git a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift
index 7cbd2bc71..f6f91567f 100644
--- a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift
+++ b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationData.swift
@@ -136,6 +136,8 @@ public struct PrivacyConfigurationData {
case minSupportedVersion
case rollout
case cohorts
+ case targets
+ case settings
}
public struct Rollout: Hashable {
@@ -169,10 +171,27 @@ public struct PrivacyConfigurationData {
}
}
+ public struct Target {
+ enum CodingKeys: String {
+ case localeCountry
+ case localeLanguage
+ }
+
+ public let localeCountry: String?
+ public let localeLanguage: String?
+
+ public init(json: [String: Any]) {
+ self.localeCountry = json[CodingKeys.localeCountry.rawValue] as? String
+ self.localeLanguage = json[CodingKeys.localeLanguage.rawValue] as? String
+ }
+ }
+
public let state: FeatureState
public let minSupportedVersion: FeatureSupportedVersion?
public let rollout: Rollout?
public let cohorts: [Cohort]?
+ public let targets: [Target]?
+ public let settings: String?
public init?(json: [String: Any]) {
guard let state = json[CodingKeys.state.rawValue] as? String else {
@@ -194,6 +213,19 @@ public struct PrivacyConfigurationData {
} else {
cohorts = nil
}
+
+ if let targetData = json[CodingKeys.targets.rawValue] as? [[String: Any]] {
+ targets = targetData.compactMap { Target(json: $0) }
+ } else {
+ targets = nil
+ }
+
+ if let settingsData = json[CodingKeys.settings.rawValue] {
+ let jsonData = try? JSONSerialization.data(withJSONObject: settingsData, options: [])
+ settings = jsonData != nil ? String(data: jsonData!, encoding: .utf8) : nil
+ } else {
+ settings = nil
+ }
}
}
diff --git a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift
index 219a01613..004729e70 100644
--- a/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift
+++ b/Sources/BrowserServicesKit/PrivacyConfig/PrivacyConfigurationManager.swift
@@ -55,6 +55,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging {
private let localProtection: DomainsProtectionStore
private let errorReporting: EventMapping?
private let installDate: Date?
+ private let locale: Locale
public let internalUserDecider: InternalUserDecider
@@ -110,12 +111,14 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging {
localProtection: DomainsProtectionStore,
errorReporting: EventMapping? = nil,
internalUserDecider: InternalUserDecider,
+ locale: Locale = Locale.current,
installDate: Date? = nil
) {
self.embeddedDataProvider = embeddedDataProvider
self.localProtection = localProtection
self.errorReporting = errorReporting
self.internalUserDecider = internalUserDecider
+ self.locale = locale
self.installDate = installDate
reload(etag: fetchedETag, data: fetchedData)
@@ -127,6 +130,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging {
identifier: fetchedData.etag,
localProtection: localProtection,
internalUserDecider: internalUserDecider,
+ locale: locale,
installDate: installDate)
}
@@ -134,6 +138,7 @@ public class PrivacyConfigurationManager: PrivacyConfigurationManaging {
identifier: embeddedConfigData.etag,
localProtection: localProtection,
internalUserDecider: internalUserDecider,
+ locale: locale,
installDate: installDate)
}
diff --git a/Sources/Common/Concurrency/TaskExtension.swift b/Sources/Common/Concurrency/TaskExtension.swift
index a407e4406..65d974a36 100644
--- a/Sources/Common/Concurrency/TaskExtension.swift
+++ b/Sources/Common/Concurrency/TaskExtension.swift
@@ -18,51 +18,72 @@
import Foundation
+public struct Sleeper {
+
+ public static let `default` = Sleeper(sleep: {
+ try await Task.sleep(interval: $0)
+ })
+
+ private let sleep: (TimeInterval) async throws -> Void
+
+ public init(sleep: @escaping (TimeInterval) async throws -> Void) {
+ self.sleep = sleep
+ }
+
+ @available(macOS 13.0, iOS 16.0, *)
+ public init(clock: any Clock) {
+ self.sleep = { interval in
+ try await clock.sleep(for: .nanoseconds(UInt64(interval * Double(NSEC_PER_SEC))))
+ }
+ }
+
+ public func sleep(for interval: TimeInterval) async throws {
+ try await sleep(interval)
+ }
+
+}
+
+public func performPeriodicJob(withDelay delay: TimeInterval? = nil,
+ interval: TimeInterval,
+ sleeper: Sleeper = .default,
+ operation: @escaping @Sendable () async throws -> Void,
+ cancellationHandler: (@Sendable () async -> Void)? = nil) async throws -> Never {
+
+ do {
+ if let delay {
+ try await sleeper.sleep(for: delay)
+ }
+
+ repeat {
+ try await operation()
+
+ try await sleeper.sleep(for: interval)
+ } while true
+ } catch let error as CancellationError {
+ await cancellationHandler?()
+ throw error
+ }
+}
+
public extension Task where Success == Never, Failure == Error {
static func periodic(delay: TimeInterval? = nil,
interval: TimeInterval,
+ sleeper: Sleeper = .default,
operation: @escaping @Sendable () async -> Void,
cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task {
- Task {
- do {
- if let delay {
- try await Task.sleep(interval: delay)
- }
-
- repeat {
- await operation()
-
- try await Task.sleep(interval: interval)
- } while true
- } catch {
- await cancellationHandler?()
- throw error
- }
- }
+ return periodic(delay: delay, interval: interval, sleeper: sleeper, operation: { await operation() } as @Sendable () async throws -> Void, cancellationHandler: cancellationHandler)
}
static func periodic(delay: TimeInterval? = nil,
interval: TimeInterval,
+ sleeper: Sleeper = .default,
operation: @escaping @Sendable () async throws -> Void,
cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task {
Task {
- do {
- if let delay {
- try await Task.sleep(interval: delay)
- }
-
- repeat {
- try await operation()
-
- try await Task.sleep(interval: interval)
- } while true
- } catch {
- await cancellationHandler?()
- throw error
- }
+ try await performPeriodicJob(withDelay: delay, interval: interval, sleeper: sleeper, operation: operation, cancellationHandler: cancellationHandler)
}
}
}
diff --git a/Sources/Common/Extensions/DateExtension.swift b/Sources/Common/Extensions/DateExtension.swift
index fdabdee83..2a4ca74ea 100644
--- a/Sources/Common/Extensions/DateExtension.swift
+++ b/Sources/Common/Extensions/DateExtension.swift
@@ -64,7 +64,11 @@ public extension Date {
}
var startOfDay: Date {
- return Calendar.current.startOfDay(for: self)
+ return Calendar.current.startOfDay(for: self)
+ }
+
+ func daysAgo(_ days: Int) -> Date {
+ Calendar.current.date(byAdding: .day, value: -days, to: self)!
}
static var startOfMinuteNow: Date {
diff --git a/Sources/Common/Extensions/HashExtension.swift b/Sources/Common/Extensions/HashExtension.swift
index b6752cf57..13095cf63 100644
--- a/Sources/Common/Extensions/HashExtension.swift
+++ b/Sources/Common/Extensions/HashExtension.swift
@@ -42,8 +42,13 @@ extension Data {
extension String {
public var sha1: String {
- let dataBytes = data(using: .utf8)!
- return dataBytes.sha1
+ let result = utf8data.sha1
+ return result
+ }
+
+ public var sha256: String {
+ let result = utf8data.sha256
+ return result
}
}
diff --git a/Sources/Common/Extensions/StringExtension.swift b/Sources/Common/Extensions/StringExtension.swift
index 09050cfe2..9282a43b4 100644
--- a/Sources/Common/Extensions/StringExtension.swift
+++ b/Sources/Common/Extensions/StringExtension.swift
@@ -394,9 +394,9 @@ public extension String {
// MARK: Regex
- func matches(_ regex: NSRegularExpression) -> Bool {
- let matches = regex.matches(in: self, options: .anchored, range: self.fullRange)
- return matches.count == 1
+ func matches(_ regex: RegEx) -> Bool {
+ let firstMatch = firstMatch(of: regex, options: .anchored)
+ return firstMatch != nil
}
func matches(pattern: String, options: NSRegularExpression.Options = [.caseInsensitive]) -> Bool {
@@ -406,7 +406,7 @@ public extension String {
return matches(regex)
}
- func replacing(_ regex: NSRegularExpression, with replacement: String) -> String {
+ func replacing(_ regex: RegEx, with replacement: String) -> String {
regex.stringByReplacingMatches(in: self, range: self.fullRange, withTemplate: replacement)
}
diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift
index d19751148..ce68773d5 100644
--- a/Sources/Common/Extensions/URLExtension.swift
+++ b/Sources/Common/Extensions/URLExtension.swift
@@ -354,22 +354,24 @@ extension URL {
// MARK: - Parameters
+ @_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order.
public func appendingParameters(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL
where QueryParams.Element == (key: String, value: String) {
+ let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
+ URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
+ })
+ return result
+ }
- return parameters.reduce(self) { partialResult, parameter in
- partialResult.appendingParameter(
- name: parameter.key,
- value: parameter.value,
- allowedReservedCharacters: allowedReservedCharacters
- )
- }
+ public func appendingParameters(_ parameters: KeyValuePairs, allowedReservedCharacters: CharacterSet? = nil) -> URL {
+ let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
+ URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
+ })
+ return result
}
public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL {
- let queryItem = URLQueryItem(percentEncodingName: name,
- value: value,
- withAllowedCharacters: allowedReservedCharacters)
+ let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
return self.appending(percentEncodedQueryItem: queryItem)
}
@@ -378,13 +380,15 @@ extension URL {
}
public func appending(percentEncodedQueryItems: [URLQueryItem]) -> URL {
- guard var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self }
+ guard !percentEncodedQueryItems.isEmpty,
+ var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self }
var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]()
existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems)
components.percentEncodedQueryItems = existingPercentEncodedQueryItems
+ let result = components.url ?? self
- return components.url ?? self
+ return result
}
public func getQueryItems() -> [URLQueryItem]? {
diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift
new file mode 100644
index 000000000..d16669d4b
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/API/APIClient.swift
@@ -0,0 +1,129 @@
+//
+// APIClient.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import Foundation
+import Networking
+
+extension APIClient {
+ // used internally for testing
+ protocol Mockable {
+ func load(_ requestConfig: Request) async throws -> Request.Response
+ }
+}
+extension APIClient: APIClient.Mockable {}
+
+public protocol APIClientEnvironment {
+ func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2
+ func url(for requestType: APIRequestType) -> URL
+}
+
+public extension MaliciousSiteDetector {
+ enum APIEnvironment: APIClientEnvironment {
+
+ case production
+ case staging
+
+ var endpoint: URL {
+ switch self {
+ case .production: URL(string: "https://duckduckgo.com/api/protection/")!
+ case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")!
+ }
+ }
+
+ var defaultHeaders: APIRequestV2.HeadersV2 {
+ .init(userAgent: Networking.APIRequest.Headers.userAgent)
+ }
+
+ enum APIPath {
+ static let filterSet = "filterSet"
+ static let hashPrefix = "hashPrefix"
+ static let matches = "matches"
+ }
+
+ enum QueryParameter {
+ static let category = "category"
+ static let revision = "revision"
+ static let hashPrefix = "hashPrefix"
+ }
+
+ public func url(for requestType: APIRequestType) -> URL {
+ switch requestType {
+ case .hashPrefixSet(let configuration):
+ endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([
+ QueryParameter.category: configuration.threatKind.rawValue,
+ QueryParameter.revision: (configuration.revision ?? 0).description,
+ ])
+ case .filterSet(let configuration):
+ endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([
+ QueryParameter.category: configuration.threatKind.rawValue,
+ QueryParameter.revision: (configuration.revision ?? 0).description,
+ ])
+ case .matches(let configuration):
+ endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix)
+ }
+ }
+
+ public func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 {
+ defaultHeaders
+ }
+ }
+
+}
+
+struct APIClient {
+
+ let environment: APIClientEnvironment
+ private let service: APIService
+
+ init(environment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared)) {
+ self.environment = environment
+ self.service = service
+ }
+
+ func load(_ requestConfig: R) async throws -> R.Response {
+ let requestType = requestConfig.requestType
+ let headers = environment.headers(for: requestType)
+ let url = environment.url(for: requestType)
+
+ let apiRequest = APIRequestV2(url: url, method: .get, headers: headers, timeoutInterval: requestConfig.timeout ?? 60)
+ let response = try await service.fetch(request: apiRequest)
+ let result: R.Response = try response.decodeBody()
+
+ return result
+ }
+
+}
+
+// MARK: - Convenience
+extension APIClient.Mockable {
+ func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet {
+ let result = try await load(.filterSet(threatKind: threatKind, revision: revision))
+ return result
+ }
+
+ func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet {
+ let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision))
+ return result
+ }
+
+ func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches {
+ let result = try await load(.matches(hashPrefix: hashPrefix))
+ return result
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift
new file mode 100644
index 000000000..39fb623bd
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift
@@ -0,0 +1,112 @@
+//
+// APIRequest.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+// Enumerated request type to delegate URLs forming to an API environment instance
+public enum APIRequestType {
+ case hashPrefixSet(APIRequestType.HashPrefixes)
+ case filterSet(APIRequestType.FilterSet)
+ case matches(APIRequestType.Matches)
+}
+
+extension APIClient {
+ // Protocol for defining typed requests with a specific response type.
+ protocol Request {
+ associatedtype Response: Decodable // Strongly-typed response type
+ var requestType: APIRequestType { get } // Enumerated type of request being made
+ var timeout: TimeInterval? { get }
+ }
+
+ // Protocol for requests that modify a set of malicious site detection data
+ // (returning insertions/removals along with the updated revision)
+ protocol ChangeSetRequest: Request {
+ init(threatKind: ThreatKind, revision: Int?)
+ }
+}
+extension APIClient.Request {
+ var timeout: TimeInterval? { nil }
+}
+
+public extension APIRequestType {
+ struct HashPrefixes: APIClient.ChangeSetRequest {
+ typealias Response = APIClient.Response.HashPrefixesChangeSet
+
+ let threatKind: ThreatKind
+ let revision: Int?
+
+ init(threatKind: ThreatKind, revision: Int?) {
+ self.threatKind = threatKind
+ self.revision = revision
+ }
+
+ var requestType: APIRequestType {
+ .hashPrefixSet(self)
+ }
+ }
+}
+/// extension to call generic `load(_: some Request)` method like this: `load(.hashPrefixes(…))`
+extension APIClient.Request where Self == APIRequestType.HashPrefixes {
+ static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self {
+ .init(threatKind: threatKind, revision: revision)
+ }
+}
+
+public extension APIRequestType {
+ struct FilterSet: APIClient.ChangeSetRequest {
+ typealias Response = APIClient.Response.FiltersChangeSet
+
+ let threatKind: ThreatKind
+ let revision: Int?
+
+ init(threatKind: ThreatKind, revision: Int?) {
+ self.threatKind = threatKind
+ self.revision = revision
+ }
+
+ var requestType: APIRequestType {
+ .filterSet(self)
+ }
+ }
+}
+/// extension to call generic `load(_: some Request)` method like this: `load(.filterSet(…))`
+extension APIClient.Request where Self == APIRequestType.FilterSet {
+ static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self {
+ .init(threatKind: threatKind, revision: revision)
+ }
+}
+
+public extension APIRequestType {
+ struct Matches: APIClient.Request {
+ typealias Response = APIClient.Response.Matches
+
+ let hashPrefix: String
+
+ var requestType: APIRequestType {
+ .matches(self)
+ }
+
+ var timeout: TimeInterval? { 1 }
+ }
+}
+/// extension to call generic `load(_: some Request)` method like this: `load(.matches(…))`
+extension APIClient.Request where Self == APIRequestType.Matches {
+ static func matches(hashPrefix: String) -> Self {
+ .init(hashPrefix: hashPrefix)
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift
new file mode 100644
index 000000000..eaf4f287c
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift
@@ -0,0 +1,47 @@
+//
+// ChangeSetResponse.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+extension APIClient {
+
+ public struct ChangeSetResponse: Codable, Equatable {
+ let insert: [T]
+ let delete: [T]
+ let revision: Int
+ let replace: Bool
+
+ public init(insert: [T], delete: [T], revision: Int, replace: Bool) {
+ self.insert = insert
+ self.delete = delete
+ self.revision = revision
+ self.replace = replace
+ }
+
+ public var isEmpty: Bool {
+ insert.isEmpty && delete.isEmpty
+ }
+ }
+
+ public enum Response {
+ public typealias FiltersChangeSet = ChangeSetResponse
+ public typealias HashPrefixesChangeSet = ChangeSetResponse
+ public typealias Matches = MatchResponse
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/API/MatchResponse.swift b/Sources/MaliciousSiteProtection/API/MatchResponse.swift
new file mode 100644
index 000000000..2cb6df962
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/API/MatchResponse.swift
@@ -0,0 +1,29 @@
+//
+// MatchResponse.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+extension APIClient {
+
+ public struct MatchResponse: Codable, Equatable {
+ public var matches: [Match]
+
+ public init(matches: [Match]) {
+ self.matches = matches
+ }
+ }
+
+}
diff --git a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift
similarity index 53%
rename from Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift
rename to Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift
index 81a4648d6..3e44f3bcd 100644
--- a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift
+++ b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift
@@ -1,5 +1,5 @@
//
-// Dictionary+URLQueryItem.swift
+// Logger+MaliciousSiteProtection.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -17,19 +17,15 @@
//
import Foundation
-import Common
+import os
-extension Dictionary where Key == String, Value == String {
-
- func toURLQueryItems(allowedReservedCharacters: CharacterSet? = nil) -> [URLQueryItem] {
- return self.map {
- if let allowedReservedCharacters {
- URLQueryItem(percentEncodingName: $0.key,
- value: $0.value,
- withAllowedCharacters: allowedReservedCharacters)
- } else {
- URLQueryItem(name: $0.key, value: $0.value)
- }
- }
+public extension os.Logger {
+ struct MaliciousSiteProtection {
+ public static var general = os.Logger(subsystem: "MSP", category: "General")
+ public static var api = os.Logger(subsystem: "MSP", category: "API")
+ public static var dataManager = os.Logger(subsystem: "MSP", category: "DataManager")
+ public static var updateManager = os.Logger(subsystem: "MSP", category: "UpdateManager")
}
}
+
+internal typealias Logger = os.Logger.MaliciousSiteProtection
diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift
new file mode 100644
index 000000000..1a637a73a
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift
@@ -0,0 +1,130 @@
+//
+// MaliciousSiteDetector.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import CryptoKit
+import Foundation
+import Networking
+
+public protocol MaliciousSiteDetecting {
+ /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware).
+ /// - Parameter url: The URL to evaluate.
+ /// - Returns: An optional `ThreatKind` indicating the type of threat, or `.none` if no threat is detected.
+ func evaluate(_ url: URL) async -> ThreatKind?
+}
+
+/// Class responsible for detecting malicious sites by evaluating URLs against local filters and an external API.
+/// entry point: `func evaluate(_: URL) async -> ThreatKind?`
+public final class MaliciousSiteDetector: MaliciousSiteDetecting {
+ // Type aliases for easier symbol navigation in Xcode.
+ typealias PhishingDetector = MaliciousSiteDetector
+ typealias MalwareDetector = MaliciousSiteDetector
+
+ private enum Constants {
+ static let hashPrefixStoreLength: Int = 8
+ static let hashPrefixParamLength: Int = 4
+ }
+
+ private let apiClient: APIClient.Mockable
+ private let dataManager: DataManaging
+ private let eventMapping: EventMapping
+
+ public convenience init(apiEnvironment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared), dataManager: DataManager, eventMapping: EventMapping) {
+ self.init(apiClient: APIClient(environment: apiEnvironment, service: service), dataManager: dataManager, eventMapping: eventMapping)
+ }
+
+ init(apiClient: APIClient.Mockable, dataManager: DataManaging, eventMapping: EventMapping) {
+ self.apiClient = apiClient
+ self.dataManager = dataManager
+ self.eventMapping = eventMapping
+ }
+
+ private func checkLocalFilters(hostHash: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool {
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: threatKind))
+ let matchesLocalFilters = filterSet[hostHash]?.contains(where: { regex in
+ canonicalUrl.absoluteString.matches(pattern: regex)
+ }) ?? false
+
+ return matchesLocalFilters
+ }
+
+ private func checkApiMatches(hostHash: String, canonicalUrl: URL) async -> Match? {
+ let hashPrefixParam = String(hostHash.prefix(Constants.hashPrefixParamLength))
+ let matches: [Match]
+ do {
+ matches = try await apiClient.matches(forHashPrefix: hashPrefixParam).matches
+ } catch {
+ Logger.general.error("Error fetching matches from API: \(error)")
+ return nil
+ }
+
+ if let match = matches.first(where: { match in
+ match.hash == hostHash && canonicalUrl.absoluteString.matches(pattern: match.regex)
+ }) {
+ return match
+ }
+ return nil
+ }
+
+ /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware).
+ public func evaluate(_ url: URL) async -> ThreatKind? {
+ guard let canonicalHost = url.canonicalHost(),
+ let canonicalUrl = url.canonicalURL() else { return .none }
+
+ let hostHash = canonicalHost.sha256
+ let hashPrefix = String(hostHash.prefix(Constants.hashPrefixStoreLength))
+
+ // 1. Check for matching hash prefixes.
+ // The hash prefix list serves as a representation of the entire database:
+ // every malicious website will have a hash prefix that it collides with.
+ var hashPrefixMatchingThreatKinds = [ThreatKind]()
+ for threatKind in ThreatKind.allCases { // e.g., phishing, malware, etc.
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind))
+ if hashPrefixes.contains(hashPrefix) {
+ hashPrefixMatchingThreatKinds.append(threatKind)
+ }
+ }
+
+ // Return no threats if no matching hash prefixes are found in the database.
+ guard !hashPrefixMatchingThreatKinds.isEmpty else { return .none }
+
+ // 2. Check local Filter Sets.
+ // The filter set acts as a local cache of some database entries, containing
+ // the 5000 most common threats (or those most likely to collide with daily
+ // browsing behaviors, based on Clickhouse's top 10k, ranked by Netcraft's risk rating).
+ for threatKind in hashPrefixMatchingThreatKinds {
+ let matches = await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind)
+ if matches {
+ eventMapping.fire(.errorPageShown(clientSideHit: true, threatKind: threatKind))
+ return threatKind
+ }
+ }
+
+ // 3. If no locally cached filters matched, we will still make a request to the API
+ // to check for potential matches on our backend.
+ let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl)
+ if let match {
+ let threatKind = match.category.flatMap(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0]
+ eventMapping.fire(.errorPageShown(clientSideHit: false, threatKind: threatKind))
+ return threatKind
+ }
+
+ return .none
+ }
+
+}
diff --git a/Sources/PhishingDetection/PhishingDetectionEvents.swift b/Sources/MaliciousSiteProtection/Model/Event.swift
similarity index 92%
rename from Sources/PhishingDetection/PhishingDetectionEvents.swift
rename to Sources/MaliciousSiteProtection/Model/Event.swift
index a788e09ff..8903f4d70 100644
--- a/Sources/PhishingDetection/PhishingDetectionEvents.swift
+++ b/Sources/MaliciousSiteProtection/Model/Event.swift
@@ -1,5 +1,5 @@
//
-// PhishingDetectionEvents.swift
+// Event.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -26,8 +26,8 @@ public extension PixelKit {
}
}
-public enum PhishingDetectionEvents: PixelKitEventV2 {
- case errorPageShown(clientSideHit: Bool)
+public enum Event: PixelKitEventV2 {
+ case errorPageShown(clientSideHit: Bool, threatKind: ThreatKind)
case visitSite
case iframeLoaded
case updateTaskFailed48h(error: Error?)
@@ -50,7 +50,7 @@ public enum PhishingDetectionEvents: PixelKitEventV2 {
public var parameters: [String: String]? {
switch self {
- case .errorPageShown(let clientSideHit):
+ case .errorPageShown(let clientSideHit, threatKind: _):
return [PixelKit.Parameters.clientSideHit: String(clientSideHit)]
case .visitSite:
return [:]
diff --git a/Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift b/Sources/MaliciousSiteProtection/Model/Filter.swift
similarity index 65%
rename from Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift
rename to Sources/MaliciousSiteProtection/Model/Filter.swift
index 86b79d477..674a176e0 100644
--- a/Tests/PhishingDetectionTests/Mocks/BackgroundActivitySchedulerMock.swift
+++ b/Sources/MaliciousSiteProtection/Model/Filter.swift
@@ -1,5 +1,5 @@
//
-// BackgroundActivitySchedulerMock.swift
+// Filter.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -17,19 +17,18 @@
//
import Foundation
-import PhishingDetection
-actor MockBackgroundActivityScheduler: BackgroundActivityScheduling {
- var startCalled = false
- var stopCalled = false
- var interval: TimeInterval = 1
- var identifier: String = "test"
+public struct Filter: Codable, Hashable {
+ public var hash: String
+ public var regex: String
- func start() {
- startCalled = true
+ enum CodingKeys: String, CodingKey {
+ case hash
+ case regex
}
- func stop() {
- stopCalled = true
+ public init(hash: String, regex: String) {
+ self.hash = hash
+ self.regex = regex
}
}
diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift
new file mode 100644
index 000000000..b67cd82ef
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift
@@ -0,0 +1,78 @@
+//
+// FilterDictionary.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+struct FilterDictionary: Codable, Equatable {
+
+ /// Filter set revision
+ var revision: Int
+
+ /// [Hash: [RegEx]] mapping
+ ///
+ /// - **Key**: SHA256 hash sum of a canonical host name
+ /// - **Value**: An array of regex patterns used to match whole URLs
+ ///
+ /// ```
+ /// {
+ /// "3aeb002460381c6f258e8395d3026f571f0d9a76488dcd837639b13aed316560" : [
+ /// "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?[\\/\\\\]+BETS1O\\-GIRIS[\\/\\\\]+BETS1O(?:[\\/\\\\]+|\\?|$)"
+ /// ],
+ /// ...
+ /// }
+ /// ```
+ var filters: [String: Set]
+
+ /// Subscript to access regex patterns by SHA256 host name hash
+ subscript(hash: String) -> Set? {
+ filters[hash]
+ }
+
+ mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter {
+ for filter in itemsToDelete {
+ // Remove the filter from the Set stored in the Dictionary by hash used as a key.
+ // If the Set becomes empty – remove the Set value from the Dictionary.
+ //
+ // The following code is equivalent to this one but without the Set value being copied
+ // or key being searched multiple times:
+ /*
+ if var filterSet = self.filters[filter.hash] {
+ filterSet.remove(filter.regex)
+ if filterSet.isEmpty {
+ self.filters[filter.hash] = nil
+ } else {
+ self.filters[filter.hash] = filterSet
+ }
+ }
+ */
+ withUnsafeMutablePointer(to: &filters[filter.hash]) { item in
+ item.pointee?.remove(filter.regex)
+ if item.pointee?.isEmpty == true {
+ item.pointee = nil
+ }
+ }
+ }
+ }
+
+ mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter {
+ for filter in itemsToAdd {
+ filters[filter.hash, default: []].insert(filter.regex)
+ }
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift
new file mode 100644
index 000000000..7aec5244d
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift
@@ -0,0 +1,45 @@
+//
+// HashPrefixSet.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+/// Structure storing a Set of hash prefixes ["6fe1e7c8","1d760415",...] and a revision of the set.
+struct HashPrefixSet: Codable, Equatable {
+
+ var revision: Int
+ var set: Set
+
+ init(revision: Int, items: some Sequence) {
+ self.revision = revision
+ self.set = Set(items)
+ }
+
+ mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String {
+ set.subtract(itemsToDelete)
+ }
+
+ mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String {
+ set.formUnion(itemsToAdd)
+ }
+
+ @inline(__always)
+ func contains(_ item: String) -> Bool {
+ set.contains(item)
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift
new file mode 100644
index 000000000..8a23785ae
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift
@@ -0,0 +1,71 @@
+//
+// IncrementallyUpdatableDataSet.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+protocol IncrementallyUpdatableDataSet: Codable, Equatable {
+ /// Set Element Type (Hash Prefix or Filter)
+ associatedtype Element: Codable, Hashable
+ /// API Request type used to fetch updates for the data set
+ associatedtype APIRequest: APIClient.ChangeSetRequest where APIRequest.Response == APIClient.ChangeSetResponse
+
+ var revision: Int { get set }
+
+ init(revision: Int, items: some Sequence)
+
+ mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Element
+ mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Element
+
+ /// Apply ChangeSet from local data revision to actual revision loaded from API
+ mutating func apply(_ changeSet: APIClient.ChangeSetResponse)
+}
+
+extension IncrementallyUpdatableDataSet {
+ mutating func apply(_ changeSet: APIClient.ChangeSetResponse) {
+ if changeSet.replace {
+ self = .init(revision: changeSet.revision, items: changeSet.insert)
+ } else {
+ self.subtract(changeSet.delete)
+ self.formUnion(changeSet.insert)
+ self.revision = changeSet.revision
+ }
+ }
+}
+
+extension HashPrefixSet: IncrementallyUpdatableDataSet {
+ typealias Element = String
+ typealias APIRequest = APIRequestType.HashPrefixes
+
+ static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest {
+ .hashPrefixes(threatKind: threatKind, revision: revision)
+ }
+}
+
+extension FilterDictionary: IncrementallyUpdatableDataSet {
+ typealias Element = Filter
+ typealias APIRequest = APIRequestType.FilterSet
+
+ init(revision: Int, items: some Sequence) {
+ let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in
+ result[filter.hash, default: []].insert(filter.regex)
+ }
+ self.init(revision: revision, filters: filtersDictionary)
+ }
+
+ static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest {
+ .filterSet(threatKind: threatKind, revision: revision)
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift
new file mode 100644
index 000000000..be67cb6fc
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift
@@ -0,0 +1,34 @@
+//
+// LoadableFromEmbeddedData.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+public protocol LoadableFromEmbeddedData {
+ /// Set Element Type (Hash Prefix or Filter)
+ associatedtype Element
+ /// Decoded data type stored in the embedded json file
+ associatedtype EmbeddedDataSet: Decodable, Sequence where EmbeddedDataSet.Element == Self.Element
+
+ init(revision: Int, items: some Sequence)
+}
+
+extension HashPrefixSet: LoadableFromEmbeddedData {
+ public typealias EmbeddedDataSet = [String]
+}
+
+extension FilterDictionary: LoadableFromEmbeddedData {
+ public typealias EmbeddedDataSet = [Filter]
+}
diff --git a/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift b/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift
new file mode 100644
index 000000000..8da2523f5
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/MaliciousSiteError.swift
@@ -0,0 +1,94 @@
+//
+// MaliciousSiteError.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+public struct MaliciousSiteError: Error, Equatable {
+
+ public enum Code: Int {
+ case phishing = 1
+ // case malware = 2
+ }
+ public let code: Code
+ public let failingUrl: URL
+
+ public init(code: Code, failingUrl: URL) {
+ self.code = code
+ self.failingUrl = failingUrl
+ }
+
+ public init(threat: ThreatKind, failingUrl: URL) {
+ let code: Code
+ switch threat {
+ case .phishing:
+ code = .phishing
+ // case .malware:
+ // code = .malware
+ }
+ self.init(code: code, failingUrl: failingUrl)
+ }
+
+ public var threatKind: ThreatKind {
+ switch code {
+ case .phishing: .phishing
+ // case .malware: .malware
+ }
+ }
+
+}
+
+extension MaliciousSiteError: _ObjectiveCBridgeableError {
+
+ public init?(_bridgedNSError error: NSError) {
+ guard error.domain == MaliciousSiteError.errorDomain,
+ let code = Code(rawValue: error.code),
+ let failingUrl = error.userInfo[NSURLErrorFailingURLErrorKey] as? URL else { return nil }
+ self.code = code
+ self.failingUrl = failingUrl
+ }
+
+}
+
+extension MaliciousSiteError: LocalizedError {
+
+ public var errorDescription: String? {
+ switch code {
+ case .phishing:
+ return "Phishing detected"
+ // case .malware:
+ // return "Malware detected"
+ }
+ }
+
+}
+
+extension MaliciousSiteError: CustomNSError {
+ public static let errorDomain: String = "MaliciousSiteError"
+
+ public var errorCode: Int {
+ code.rawValue
+ }
+
+ public var errorUserInfo: [String: Any] {
+ [
+ NSURLErrorFailingURLErrorKey: failingUrl,
+ NSLocalizedDescriptionKey: errorDescription!
+ ]
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/Model/Match.swift b/Sources/MaliciousSiteProtection/Model/Match.swift
new file mode 100644
index 000000000..e22cb597f
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/Match.swift
@@ -0,0 +1,35 @@
+//
+// Match.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+public struct Match: Codable, Hashable {
+ var hostname: String
+ var url: String
+ var regex: String
+ var hash: String
+ let category: String?
+
+ public init(hostname: String, url: String, regex: String, hash: String, category: String?) {
+ self.hostname = hostname
+ self.url = url
+ self.regex = regex
+ self.hash = hash
+ self.category = category
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift
new file mode 100644
index 000000000..a064be076
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift
@@ -0,0 +1,104 @@
+//
+// StoredData.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+protocol MaliciousSiteDataKey: Hashable {
+ associatedtype EmbeddedDataSet: Decodable
+ associatedtype DataSet: IncrementallyUpdatableDataSet, LoadableFromEmbeddedData
+
+ var dataType: DataManager.StoredDataType { get }
+ var threatKind: ThreatKind { get }
+}
+
+public extension DataManager {
+ enum StoredDataType: Hashable, CaseIterable {
+ case hashPrefixSet(HashPrefixes)
+ case filterSet(FilterSet)
+
+ enum Kind: CaseIterable {
+ case hashPrefixSet, filterSet
+ }
+ // keep to get a compiler error when number of cases changes
+ var kind: Kind {
+ switch self {
+ case .hashPrefixSet: .hashPrefixSet
+ case .filterSet: .filterSet
+ }
+ }
+
+ var dataKey: any MaliciousSiteDataKey {
+ switch self {
+ case .hashPrefixSet(let key): key
+ case .filterSet(let key): key
+ }
+ }
+
+ public var threatKind: ThreatKind {
+ switch self {
+ case .hashPrefixSet(let key): key.threatKind
+ case .filterSet(let key): key.threatKind
+ }
+ }
+
+ public static var allCases: [DataManager.StoredDataType] {
+ ThreatKind.allCases.map { threatKind in
+ Kind.allCases.map { dataKind in
+ switch dataKind {
+ case .hashPrefixSet: .hashPrefixSet(.init(threatKind: threatKind))
+ case .filterSet: .filterSet(.init(threatKind: threatKind))
+ }
+ }
+ }.flatMap { $0 }
+ }
+ }
+}
+
+public extension DataManager.StoredDataType {
+ struct HashPrefixes: MaliciousSiteDataKey {
+ typealias DataSet = HashPrefixSet
+
+ let threatKind: ThreatKind
+
+ var dataType: DataManager.StoredDataType {
+ .hashPrefixSet(self)
+ }
+ }
+}
+extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.HashPrefixes {
+ static func hashPrefixes(threatKind: ThreatKind) -> Self {
+ .init(threatKind: threatKind)
+ }
+}
+
+public extension DataManager.StoredDataType {
+ struct FilterSet: MaliciousSiteDataKey {
+ typealias DataSet = FilterDictionary
+
+ let threatKind: ThreatKind
+
+ var dataType: DataManager.StoredDataType {
+ .filterSet(self)
+ }
+ }
+}
+extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.FilterSet {
+ static func filterSet(threatKind: ThreatKind) -> Self {
+ .init(threatKind: threatKind)
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/Model/ThreatKind.swift b/Sources/MaliciousSiteProtection/Model/ThreatKind.swift
new file mode 100644
index 000000000..bec9e2996
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Model/ThreatKind.swift
@@ -0,0 +1,27 @@
+//
+// ThreatKind.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+public enum ThreatKind: String, CaseIterable, Codable, CustomStringConvertible {
+ public var description: String { rawValue }
+
+ case phishing
+ // case malware
+
+}
diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift b/Sources/MaliciousSiteProtection/Services/DataManager.swift
new file mode 100644
index 000000000..8e4426dd1
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift
@@ -0,0 +1,105 @@
+//
+// DataManager.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import os
+
+protocol DataManaging {
+ func dataSet(for key: DataKey) async -> DataKey.DataSet
+ func store(_ dataSet: DataKey.DataSet, for key: DataKey) async
+}
+
+public actor DataManager: DataManaging {
+
+ private let embeddedDataProvider: EmbeddedDataProviding
+ private let fileStore: FileStoring
+
+ public typealias FileNameProvider = (DataManager.StoredDataType) -> String
+ private nonisolated let fileNameProvider: FileNameProvider
+
+ private var store: [StoredDataType: Any] = [:]
+
+ public init(fileStore: FileStoring, embeddedDataProvider: EmbeddedDataProviding, fileNameProvider: @escaping FileNameProvider) {
+ self.embeddedDataProvider = embeddedDataProvider
+ self.fileStore = fileStore
+ self.fileNameProvider = fileNameProvider
+ }
+
+ func dataSet(for key: DataKey) -> DataKey.DataSet {
+ let dataType = key.dataType
+ // return cached dataSet if available
+ if let data = store[key.dataType] as? DataKey.DataSet {
+ return data
+ }
+
+ // read stored dataSet if it‘s newer than the embedded one
+ let dataSet = readStoredDataSet(for: key) ?? {
+ // no stored dataSet or the embedded one is newer
+ let embeddedRevision = embeddedDataProvider.revision(for: dataType)
+ let embeddedItems = embeddedDataProvider.loadDataSet(for: key)
+ return .init(revision: embeddedRevision, items: embeddedItems)
+ }()
+
+ // cache
+ store[dataType] = dataSet
+
+ return dataSet
+ }
+
+ private func readStoredDataSet(for key: DataKey) -> DataKey.DataSet? {
+ let dataType = key.dataType
+ let fileName = fileNameProvider(dataType)
+ guard let data = fileStore.read(from: fileName) else { return nil }
+
+ let storedDataSet: DataKey.DataSet
+ do {
+ storedDataSet = try JSONDecoder().decode(DataKey.DataSet.self, from: data)
+ } catch {
+ Logger.dataManager.error("Error decoding \(fileName): \(error.localizedDescription)")
+ return nil
+ }
+
+ // compare to the embedded data revision
+ let embeddedDataRevision = embeddedDataProvider.revision(for: dataType)
+ guard storedDataSet.revision >= embeddedDataRevision else {
+ Logger.dataManager.error("Stored \(fileName) is outdated: revision: \(storedDataSet.revision), embedded revision: \(embeddedDataRevision).")
+ return nil
+ }
+
+ return storedDataSet
+ }
+
+ func store(_ dataSet: DataKey.DataSet, for key: DataKey) {
+ let dataType = key.dataType
+ let fileName = fileNameProvider(dataType)
+ self.store[dataType] = dataSet
+
+ let data: Data
+ do {
+ data = try JSONEncoder().encode(dataSet)
+ } catch {
+ Logger.dataManager.error("Error encoding \(fileName): \(error.localizedDescription)")
+ assertionFailure("Failed to store data to \(fileName): \(error)")
+ return
+ }
+
+ let success = fileStore.write(data: data, to: fileName)
+ assert(success)
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift
new file mode 100644
index 000000000..942c6214a
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift
@@ -0,0 +1,56 @@
+//
+// EmbeddedDataProvider.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import CryptoKit
+
+public protocol EmbeddedDataProviding {
+ func revision(for dataType: DataManager.StoredDataType) -> Int
+ func url(for dataType: DataManager.StoredDataType) -> URL
+ func hash(for dataType: DataManager.StoredDataType) -> String
+
+ func data(withContentsOf url: URL) throws -> Data
+}
+
+extension EmbeddedDataProviding {
+
+ func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet {
+ let dataType = key.dataType
+ let url = url(for: dataType)
+ let data: Data
+ do {
+ data = try self.data(withContentsOf: url)
+#if DEBUG
+ assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)")
+#endif
+ } catch {
+ fatalError("\(self): Could not load embedded data set at “\(url)”: \(error)")
+ }
+ do {
+ let result = try JSONDecoder().decode(DataKey.EmbeddedDataSet.self, from: data)
+ return result
+ } catch {
+ fatalError("\(self): Could not decode embedded data set at “\(url)”: \(error)")
+ }
+ }
+
+ public func data(withContentsOf url: URL) throws -> Data {
+ try Data(contentsOf: url)
+ }
+
+}
diff --git a/Sources/MaliciousSiteProtection/Services/FileStore.swift b/Sources/MaliciousSiteProtection/Services/FileStore.swift
new file mode 100644
index 000000000..06418e6a2
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Services/FileStore.swift
@@ -0,0 +1,67 @@
+//
+// FileStore.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import os
+
+public protocol FileStoring {
+ @discardableResult func write(data: Data, to filename: String) -> Bool
+ func read(from filename: String) -> Data?
+}
+
+public struct FileStore: FileStoring, CustomDebugStringConvertible {
+ private let dataStoreURL: URL
+
+ public init(dataStoreURL: URL) {
+ self.dataStoreURL = dataStoreURL
+ createDirectoryIfNeeded()
+ }
+
+ private func createDirectoryIfNeeded() {
+ do {
+ try FileManager.default.createDirectory(at: dataStoreURL, withIntermediateDirectories: true, attributes: nil)
+ } catch {
+ Logger.dataManager.error("Failed to create directory: \(error.localizedDescription)")
+ }
+ }
+
+ public func write(data: Data, to filename: String) -> Bool {
+ let fileURL = dataStoreURL.appendingPathComponent(filename)
+ do {
+ try data.write(to: fileURL)
+ return true
+ } catch {
+ Logger.dataManager.error("Error writing to directory: \(error.localizedDescription)")
+ return false
+ }
+ }
+
+ public func read(from filename: String) -> Data? {
+ let fileURL = dataStoreURL.appendingPathComponent(filename)
+ do {
+ return try Data(contentsOf: fileURL)
+ } catch {
+ Logger.dataManager.error("Error accessing application support directory: \(error)")
+ return nil
+ }
+ }
+
+ public var debugDescription: String {
+ return "<\(type(of: self)) - \"\(dataStoreURL.path)\">"
+ }
+}
diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift
new file mode 100644
index 000000000..57394edbf
--- /dev/null
+++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift
@@ -0,0 +1,101 @@
+//
+// UpdateManager.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import Foundation
+import Networking
+import os
+
+protocol UpdateManaging {
+ func updateData(for key: some MaliciousSiteDataKey) async
+
+ func startPeriodicUpdates() -> Task
+}
+
+public struct UpdateManager: UpdateManaging {
+
+ private let apiClient: APIClient.Mockable
+ private let dataManager: DataManaging
+
+ public typealias UpdateIntervalProvider = (DataManager.StoredDataType) -> TimeInterval?
+ private let updateIntervalProvider: UpdateIntervalProvider
+ private let sleeper: Sleeper
+
+ public init(apiEnvironment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared), dataManager: DataManager, updateIntervalProvider: @escaping UpdateIntervalProvider) {
+ self.init(apiClient: APIClient(environment: apiEnvironment, service: service), dataManager: dataManager, updateIntervalProvider: updateIntervalProvider)
+ }
+
+ init(apiClient: APIClient.Mockable, dataManager: DataManaging, sleeper: Sleeper = .default, updateIntervalProvider: @escaping UpdateIntervalProvider) {
+ self.apiClient = apiClient
+ self.dataManager = dataManager
+ self.updateIntervalProvider = updateIntervalProvider
+ self.sleeper = sleeper
+ }
+
+ func updateData(for key: DataKey) async {
+ // load currently stored data set
+ var dataSet = await dataManager.dataSet(for: key)
+ let oldRevision = dataSet.revision
+
+ // get change set from current revision from API
+ let changeSet: APIClient.ChangeSetResponse
+ do {
+ let request = DataKey.DataSet.APIRequest(threatKind: key.threatKind, revision: oldRevision)
+ changeSet = try await apiClient.load(request)
+ } catch {
+ Logger.updateManager.error("error fetching filter set: \(error)")
+ return
+ }
+ guard !changeSet.isEmpty || changeSet.revision != dataSet.revision else {
+ Logger.updateManager.debug("no changes to filter set")
+ return
+ }
+
+ // apply changes
+ dataSet.apply(changeSet)
+
+ // store back
+ await self.dataManager.store(dataSet, for: key)
+ Logger.updateManager.debug("\(type(of: key)).\(key.threatKind) updated from rev.\(oldRevision) to rev.\(dataSet.revision)")
+ }
+
+ public func startPeriodicUpdates() -> Task {
+ Task.detached {
+ // run update jobs in background for every data type
+ try await withThrowingTaskGroup(of: Never.self) { group in
+ for dataType in DataManager.StoredDataType.allCases {
+ // get update interval from provider
+ guard let updateInterval = updateIntervalProvider(dataType) else { continue }
+ guard updateInterval > 0 else {
+ assertionFailure("Update interval for \(dataType) must be positive")
+ continue
+ }
+
+ group.addTask {
+ // run periodically until the parent task is cancelled
+ try await performPeriodicJob(interval: updateInterval, sleeper: sleeper) {
+ await self.updateData(for: dataType.dataKey)
+ }
+ }
+ }
+ for try await _ in group {}
+ }
+ }
+ }
+
+}
diff --git a/Sources/Navigation/Extensions/WKErrorExtension.swift b/Sources/Navigation/Extensions/WKErrorExtension.swift
index f1a5c238d..de750e766 100644
--- a/Sources/Navigation/Extensions/WKErrorExtension.swift
+++ b/Sources/Navigation/Extensions/WKErrorExtension.swift
@@ -33,6 +33,14 @@ extension WKError {
code.rawValue == NSURLErrorCancelled && _nsError.domain == NSURLErrorDomain
}
+ public var isServerCertificateUntrusted: Bool {
+ _nsError.isServerCertificateUntrusted
+ }
+}
+extension NSError {
+ public var isServerCertificateUntrusted: Bool {
+ code == NSURLErrorServerCertificateUntrusted && domain == NSURLErrorDomain
+ }
}
extension WKError {
diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift
index a197c8a9a..7e82291d7 100644
--- a/Sources/NetworkProtection/PacketTunnelProvider.swift
+++ b/Sources/NetworkProtection/PacketTunnelProvider.swift
@@ -42,7 +42,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
case rekeyAttempt(_ step: RekeyAttemptStep)
case failureRecoveryAttempt(_ step: FailureRecoveryStep)
case serverMigrationAttempt(_ step: ServerMigrationAttemptStep)
- case malformedErrorDetected(_ error: Error)
}
public enum AttemptStep: CustomDebugStringConvertible {
@@ -705,7 +704,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
// Check that the error is valid and able to be re-thrown to the OS before shutting the tunnel down
if let wrappedError = wrapped(error: error) {
// Wait for the provider to complete its pixel request.
- providerEvents.fire(.malformedErrorDetected(error))
try? await Task.sleep(interval: .seconds(2))
throw wrappedError
} else {
@@ -743,7 +741,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
// Check that the error is valid and able to be re-thrown to the OS before shutting the tunnel down
if let wrappedError = wrapped(error: error) {
// Wait for the provider to complete its pixel request.
- providerEvents.fire(.malformedErrorDetected(error))
try? await Task.sleep(interval: .seconds(2))
throw wrappedError
} else {
@@ -887,6 +884,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
@MainActor
private func stopTunnel() async throws {
connectionStatus = .disconnecting
+
await stopMonitors()
try await stopAdapter()
}
@@ -895,10 +893,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
private func stopAdapter() async throws {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in
adapter.stop { [weak self] error in
- if let self {
- self.handleAdapterStopped()
- }
-
if let error {
self?.debugEvents.fire(error.networkProtectionError)
@@ -1420,11 +1414,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
try await startMonitors(testImmediately: testImmediately)
}
- @MainActor
- public func handleAdapterStopped() {
- connectionStatus = .disconnected
- }
-
// MARK: - Monitors
private func startTunnelFailureMonitor() async {
diff --git a/Sources/Networking/README.md b/Sources/Networking/README.md
index 83a2c5ce3..751ee63d8 100644
--- a/Sources/Networking/README.md
+++ b/Sources/Networking/README.md
@@ -19,7 +19,7 @@ let request = APIRequestV2(url: HTTPURLResponse.testUrl,
responseConstraints: [.allowHTTPNotModified,
.requireETagHeader,
.requireUserAgent],
- allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))!
+ allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))
let apiService = DefaultAPIService(urlSession: URLSession.shared)
```
@@ -55,12 +55,12 @@ The `MockPIService` implementing `APIService` can be found in `BSK/TestUtils`
```
let apiResponse = (Data(), HTTPURLResponse(url: HTTPURLResponse.testUrl,
- statusCode: 200,
- httpVersion: nil,
- headerFields: nil)!)
-let mockedAPIService = MockAPIService(decodableResponse: Result.failure(SomeError.testError), apiResponse: Result.success(apiResponse) )
+ statusCode: 200,
+ httpVersion: nil,
+ headerFields: nil)!)
+let mockedAPIService = MockAPIService(apiResponse: Result.success(apiResponse))
```
## v1 (Legacy)
-Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility.
\ No newline at end of file
+Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility.
diff --git a/Sources/Networking/v1/APIHeaders.swift b/Sources/Networking/v1/APIHeaders.swift
index 6d7f0a4b0..a5786c949 100644
--- a/Sources/Networking/v1/APIHeaders.swift
+++ b/Sources/Networking/v1/APIHeaders.swift
@@ -25,7 +25,7 @@ public extension APIRequest {
struct Headers {
public typealias UserAgent = String
- private static var userAgent: UserAgent?
+ public private(set) static var userAgent: UserAgent?
public static func setUserAgent(_ userAgent: UserAgent) {
self.userAgent = userAgent
}
diff --git a/Sources/Networking/v2/APIRequestV2.swift b/Sources/Networking/v2/APIRequestV2.swift
index 07434de67..a61604861 100644
--- a/Sources/Networking/v2/APIRequestV2.swift
+++ b/Sources/Networking/v2/APIRequestV2.swift
@@ -16,12 +16,11 @@
// limitations under the License.
//
+import Common
import Foundation
public struct APIRequestV2: CustomDebugStringConvertible {
- public typealias QueryItems = [String: String]
-
let timeoutInterval: TimeInterval
let responseConstraints: [APIResponseConstraints]?
public let urlRequest: URLRequest
@@ -37,25 +36,25 @@ public struct APIRequestV2: CustomDebugStringConvertible {
/// - cachePolicy: The request cache policy, default is `.useProtocolCachePolicy`
/// - responseRequirements: The response requirements
/// - allowedQueryReservedCharacters: The characters in this character set will not be URL encoded in the query parameters
- public init?(url: URL,
- method: HTTPRequestMethod = .get,
- queryItems: QueryItems? = nil,
- headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(),
- body: Data? = nil,
- timeoutInterval: TimeInterval = 60.0,
- cachePolicy: URLRequest.CachePolicy? = nil,
- responseConstraints: [APIResponseConstraints]? = nil,
- allowedQueryReservedCharacters: CharacterSet? = nil) {
+ public init(
+ url: URL,
+ method: HTTPRequestMethod = .get,
+ queryItems: QueryParams?,
+ headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(),
+ body: Data? = nil,
+ timeoutInterval: TimeInterval = 60.0,
+ cachePolicy: URLRequest.CachePolicy? = nil,
+ responseConstraints: [APIResponseConstraints]? = nil,
+ allowedQueryReservedCharacters: CharacterSet? = nil
+ ) where QueryParams.Element == (key: String, value: String) {
+
self.timeoutInterval = timeoutInterval
self.responseConstraints = responseConstraints
- // Generate URL request
- guard var urlComps = URLComponents(url: url, resolvingAgainstBaseURL: true) else {
- return nil
- }
- urlComps.queryItems = queryItems?.toURLQueryItems(allowedReservedCharacters: allowedQueryReservedCharacters)
- guard let finalURL = urlComps.url else {
- return nil
+ let finalURL = if let queryItems {
+ url.appendingParameters(queryItems, allowedReservedCharacters: allowedQueryReservedCharacters)
+ } else {
+ url
}
var request = URLRequest(url: finalURL, timeoutInterval: timeoutInterval)
request.allHTTPHeaderFields = headers?.httpHeaders
@@ -67,6 +66,19 @@ public struct APIRequestV2: CustomDebugStringConvertible {
self.urlRequest = request
}
+ public init(
+ url: URL,
+ method: HTTPRequestMethod = .get,
+ headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(),
+ body: Data? = nil,
+ timeoutInterval: TimeInterval = 60.0,
+ cachePolicy: URLRequest.CachePolicy? = nil,
+ responseConstraints: [APIResponseConstraints]? = nil,
+ allowedQueryReservedCharacters: CharacterSet? = nil
+ ) {
+ self.init(url: url, method: method, queryItems: [String: String]?.none, headers: headers, body: body, timeoutInterval: timeoutInterval, cachePolicy: cachePolicy, responseConstraints: responseConstraints, allowedQueryReservedCharacters: allowedQueryReservedCharacters)
+ }
+
public var debugDescription: String {
"""
APIRequestV2:
diff --git a/Sources/Networking/v2/APIResponseV2.swift b/Sources/Networking/v2/APIResponseV2.swift
index 1b178fd93..8987e377b 100644
--- a/Sources/Networking/v2/APIResponseV2.swift
+++ b/Sources/Networking/v2/APIResponseV2.swift
@@ -20,8 +20,13 @@ import Foundation
import os.log
public struct APIResponseV2 {
- let data: Data?
- let httpResponse: HTTPURLResponse
+ public let data: Data?
+ public let httpResponse: HTTPURLResponse
+
+ public init(data: Data?, httpResponse: HTTPURLResponse) {
+ self.data = data
+ self.httpResponse = httpResponse
+ }
}
public extension APIResponseV2 {
diff --git a/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift b/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift
index d10fecd56..ffff91188 100644
--- a/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift
+++ b/Sources/Onboarding/ContextualDaxDialogs/OnboardingSuggestionsViewModel.swift
@@ -19,8 +19,8 @@
import Foundation
public protocol OnboardingNavigationDelegate: AnyObject {
- func searchFor(_ query: String)
- func navigateTo(url: URL)
+ func searchFromOnboarding(for query: String)
+ func navigateFromOnboarding(to url: URL)
}
public protocol OnboardingSearchSuggestionsPixelReporting {
@@ -52,7 +52,7 @@ public struct OnboardingSearchSuggestionsViewModel {
public func listItemPressed(_ item: ContextualOnboardingListItem) {
pixelReporter.trackSearchSuggetionOptionTapped()
- delegate?.searchFor(item.title)
+ delegate?.searchFromOnboarding(for: item.title)
}
}
@@ -82,6 +82,6 @@ public struct OnboardingSiteSuggestionsViewModel {
public func listItemPressed(_ item: ContextualOnboardingListItem) {
guard let url = URL(string: item.title) else { return }
pixelReporter.trackSiteSuggetionOptionTapped()
- delegate?.navigateTo(url: url)
+ delegate?.navigateFromOnboarding(to: url)
}
}
diff --git a/Sources/PhishingDetection/Logger+PhishingDetection.swift b/Sources/PhishingDetection/Logger+PhishingDetection.swift
deleted file mode 100644
index 96a606772..000000000
--- a/Sources/PhishingDetection/Logger+PhishingDetection.swift
+++ /dev/null
@@ -1,29 +0,0 @@
-//
-// Logger+PhishingDetection.swift
-//
-// Copyright © 2023 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import os
-
-public extension Logger {
- static var phishingDetection: Logger = { Logger(subsystem: "Phishing Detection", category: "") }()
- static var phishingDetectionClient: Logger = { Logger(subsystem: "Phishing Detection", category: "APIClient") }()
- static var phishingDetectionTasks: Logger = { Logger(subsystem: "Phishing Detection", category: "BackgroundActivities") }()
- static var phishingDetectionDataProvider: Logger = { Logger(subsystem: "Phishing Detection", category: "DataProvider") }()
- static var phishingDetectionDataStore: Logger = { Logger(subsystem: "Phishing Detection", category: "DataStore") }()
- static var phishingDetectionUpdateManager: Logger = { Logger(subsystem: "Phishing Detection", category: "UpdateManager") }()
-}
diff --git a/Sources/PhishingDetection/PhishingDetectionClient.swift b/Sources/PhishingDetection/PhishingDetectionClient.swift
deleted file mode 100644
index 942075b71..000000000
--- a/Sources/PhishingDetection/PhishingDetectionClient.swift
+++ /dev/null
@@ -1,177 +0,0 @@
-//
-// PhishingDetectionClient.swift
-//
-// Copyright © 2023 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import Common
-import os
-
-public struct HashPrefixResponse: Codable, Equatable {
- public var insert: [String]
- public var delete: [String]
- public var revision: Int
- public var replace: Bool
-
- public init(insert: [String], delete: [String], revision: Int, replace: Bool) {
- self.insert = insert
- self.delete = delete
- self.revision = revision
- self.replace = replace
- }
-}
-
-public struct FilterSetResponse: Codable, Equatable {
- public var insert: [Filter]
- public var delete: [Filter]
- public var revision: Int
- public var replace: Bool
-
- public init(insert: [Filter], delete: [Filter], revision: Int, replace: Bool) {
- self.insert = insert
- self.delete = delete
- self.revision = revision
- self.replace = replace
- }
-}
-
-public struct MatchResponse: Codable, Equatable {
- public var matches: [Match]
-}
-
-public protocol PhishingDetectionClientProtocol {
- func getFilterSet(revision: Int) async -> FilterSetResponse
- func getHashPrefixes(revision: Int) async -> HashPrefixResponse
- func getMatches(hashPrefix: String) async -> [Match]
-}
-
-public protocol URLSessionProtocol {
- func data(for request: URLRequest) async throws -> (Data, URLResponse)
-}
-
-extension URLSession: URLSessionProtocol {}
-
-extension URLSessionProtocol {
- public static var defaultSession: URLSessionProtocol {
- return URLSession.shared
- }
-}
-
-public class PhishingDetectionAPIClient: PhishingDetectionClientProtocol {
-
- public enum Environment {
- case production
- case staging
- }
-
- enum Constants {
- static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")!
- static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")!
- enum APIPath: String {
- case filterSet
- case hashPrefix
- case matches
- }
- }
-
- private let endpointURL: URL
- private let session: URLSessionProtocol!
- private var headers: [String: String]? = [:]
-
- var filterSetURL: URL {
- endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue)
- }
-
- var hashPrefixURL: URL {
- endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue)
- }
-
- var matchesURL: URL {
- endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue)
- }
-
- public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) {
- switch environment {
- case .production:
- endpointURL = Constants.productionEndpoint
- case .staging:
- endpointURL = Constants.stagingEndpoint
- }
- self.session = session
- }
-
- public func getFilterSet(revision: Int) async -> FilterSetResponse {
- guard let url = createURL(for: .filterSet, revision: revision) else {
- logDebug("🔸 Invalid filterSet revision URL: \(revision)")
- return FilterSetResponse(insert: [], delete: [], revision: revision, replace: false)
- }
- return await fetch(url: url, responseType: FilterSetResponse.self) ?? FilterSetResponse(insert: [], delete: [], revision: revision, replace: false)
- }
-
- public func getHashPrefixes(revision: Int) async -> HashPrefixResponse {
- guard let url = createURL(for: .hashPrefix, revision: revision) else {
- logDebug("🔸 Invalid hashPrefix revision URL: \(revision)")
- return HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false)
- }
- return await fetch(url: url, responseType: HashPrefixResponse.self) ?? HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false)
- }
-
- public func getMatches(hashPrefix: String) async -> [Match] {
- let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)]
- guard let url = createURL(for: .matches, queryItems: queryItems) else {
- logDebug("🔸 Invalid matches URL: \(hashPrefix)")
- return []
- }
- return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? []
- }
-}
-
-// MARK: Private Methods
-extension PhishingDetectionAPIClient {
-
- private func logDebug(_ message: String) {
- Logger.phishingDetectionClient.debug("\(message)")
- }
-
- private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? {
- // Start with the base URL and append the path component
- var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true)
- var items = queryItems ?? []
- if let revision = revision, revision > 0 {
- items.append(URLQueryItem(name: "revision", value: String(revision)))
- }
- urlComponents?.queryItems = items.isEmpty ? nil : items
- return urlComponents?.url
- }
-
- private func fetch(url: URL, responseType: T.Type) async -> T? {
- var request = URLRequest(url: url)
- request.httpMethod = "GET"
- request.allHTTPHeaderFields = headers
-
- do {
- let (data, _) = try await session.data(for: request)
- if let response = try? JSONDecoder().decode(responseType, from: data) {
- return response
- } else {
- logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)")
- }
- } catch {
- logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)")
- }
- return nil
- }
-}
diff --git a/Sources/PhishingDetection/PhishingDetectionDataActivities.swift b/Sources/PhishingDetection/PhishingDetectionDataActivities.swift
deleted file mode 100644
index 3f195d75e..000000000
--- a/Sources/PhishingDetection/PhishingDetectionDataActivities.swift
+++ /dev/null
@@ -1,110 +0,0 @@
-//
-// PhishingDetectionDataActivities.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import Common
-import os
-
-public protocol BackgroundActivityScheduling: Actor {
- func start()
- func stop()
-}
-
-actor BackgroundActivityScheduler: BackgroundActivityScheduling {
-
- private var task: Task?
- private var timer: Timer?
- private let interval: TimeInterval
- private let identifier: String
- private let activity: () async -> Void
-
- init(interval: TimeInterval, identifier: String, activity: @escaping () async -> Void) {
- self.interval = interval
- self.identifier = identifier
- self.activity = activity
- }
-
- func start() {
- stop()
- task = Task {
- let taskId = UUID().uuidString
- while !Task.isCancelled {
- await activity()
- do {
- Logger.phishingDetectionTasks.debug("🟢 \(self.identifier) task was executed in instance \(taskId)")
- try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))
- } catch {
- Logger.phishingDetectionTasks.error("🔴 Error \(self.identifier) task was cancelled before it could finish sleeping.")
- break
- }
- }
- }
- }
-
- func stop() {
- task?.cancel()
- task = nil
- }
-}
-
-public protocol PhishingDetectionDataActivityHandling {
- func start()
- func stop()
-}
-
-public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandling {
- private var schedulers: [BackgroundActivityScheduler]
- private var running: Bool = false
-
- var dataProvider: PhishingDetectionDataProviding
-
- public init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, phishingDetectionDataProvider: PhishingDetectionDataProviding, updateManager: PhishingDetectionUpdateManaging) {
- let hashPrefixScheduler = BackgroundActivityScheduler(
- interval: hashPrefixInterval,
- identifier: "hashPrefixes.update",
- activity: { await updateManager.updateHashPrefixes() }
- )
- let filterSetScheduler = BackgroundActivityScheduler(
- interval: filterSetInterval,
- identifier: "filterSet.update",
- activity: { await updateManager.updateFilterSet() }
- )
- self.schedulers = [hashPrefixScheduler, filterSetScheduler]
- self.dataProvider = phishingDetectionDataProvider
- }
-
- public func start() {
- if !running {
- Task {
- for scheduler in schedulers {
- await scheduler.start()
- }
- }
- running = true
- }
- }
-
- public func stop() {
- Task {
- for scheduler in schedulers {
- await scheduler.stop()
- }
- }
- running = false
- }
-}
diff --git a/Sources/PhishingDetection/PhishingDetectionDataProvider.swift b/Sources/PhishingDetection/PhishingDetectionDataProvider.swift
deleted file mode 100644
index af1c87672..000000000
--- a/Sources/PhishingDetection/PhishingDetectionDataProvider.swift
+++ /dev/null
@@ -1,75 +0,0 @@
-//
-// PhishingDetectionDataProvider.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import CryptoKit
-import Common
-import os
-
-public protocol PhishingDetectionDataProviding {
- var embeddedRevision: Int { get }
- func loadEmbeddedFilterSet() -> Set
- func loadEmbeddedHashPrefixes() -> Set
-}
-
-public class PhishingDetectionDataProvider: PhishingDetectionDataProviding {
- public private(set) var embeddedRevision: Int
- var embeddedFilterSetURL: URL
- var embeddedFilterSetDataSHA: String
- var embeddedHashPrefixURL: URL
- var embeddedHashPrefixDataSHA: String
-
- public init(revision: Int, filterSetURL: URL, filterSetDataSHA: String, hashPrefixURL: URL, hashPrefixDataSHA: String) {
- embeddedFilterSetURL = filterSetURL
- embeddedFilterSetDataSHA = filterSetDataSHA
- embeddedHashPrefixURL = hashPrefixURL
- embeddedHashPrefixDataSHA = hashPrefixDataSHA
- embeddedRevision = revision
- }
-
- private func loadData(from url: URL, expectedSHA: String) throws -> Data {
- let data = try Data(contentsOf: url)
- let sha256 = SHA256.hash(data: data)
- let hashString = sha256.compactMap { String(format: "%02x", $0) }.joined()
-
- guard hashString == expectedSHA else {
- throw NSError(domain: "PhishingDetectionDataProvider", code: 1001, userInfo: [NSLocalizedDescriptionKey: "SHA mismatch"])
- }
- return data
- }
-
- public func loadEmbeddedFilterSet() -> Set {
- do {
- let filterSetData = try loadData(from: embeddedFilterSetURL, expectedSHA: embeddedFilterSetDataSHA)
- return try JSONDecoder().decode(Set.self, from: filterSetData)
- } catch {
- Logger.phishingDetectionDataProvider.error("🔴 Error: SHA mismatch for filterSet JSON file. Expected \(self.embeddedFilterSetDataSHA)")
- return []
- }
- }
-
- public func loadEmbeddedHashPrefixes() -> Set {
- do {
- let hashPrefixData = try loadData(from: embeddedHashPrefixURL, expectedSHA: embeddedHashPrefixDataSHA)
- return try JSONDecoder().decode(Set.self, from: hashPrefixData)
- } catch {
- Logger.phishingDetectionDataProvider.error("🔴 Error: SHA mismatch for hashPrefixes JSON file. Expected \(self.embeddedHashPrefixDataSHA)")
- return []
- }
- }
-}
diff --git a/Sources/PhishingDetection/PhishingDetectionDataStore.swift b/Sources/PhishingDetection/PhishingDetectionDataStore.swift
deleted file mode 100644
index f247f90b8..000000000
--- a/Sources/PhishingDetection/PhishingDetectionDataStore.swift
+++ /dev/null
@@ -1,266 +0,0 @@
-//
-// PhishingDetectionDataStore.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import Common
-import os
-
-enum PhishingDetectionDataError: Error {
- case empty
-}
-
-public struct Filter: Codable, Hashable {
- public var hashValue: String
- public var regex: String
-
- enum CodingKeys: String, CodingKey {
- case hashValue = "hash"
- case regex
- }
-
- public init(hashValue: String, regex: String) {
- self.hashValue = hashValue
- self.regex = regex
- }
-}
-
-public struct Match: Codable, Hashable {
- var hostname: String
- var url: String
- var regex: String
- var hash: String
-
- public init(hostname: String, url: String, regex: String, hash: String) {
- self.hostname = hostname
- self.url = url
- self.regex = regex
- self.hash = hash
- }
-}
-
-public protocol PhishingDetectionDataSaving {
- var filterSet: Set { get }
- var hashPrefixes: Set { get }
- var currentRevision: Int { get }
- func saveFilterSet(set: Set)
- func saveHashPrefixes(set: Set)
- func saveRevision(_ revision: Int)
-}
-
-public class PhishingDetectionDataStore: PhishingDetectionDataSaving {
- private lazy var _filterSet: Set = {
- loadFilterSet()
- }()
-
- private lazy var _hashPrefixes: Set = {
- loadHashPrefix()
- }()
-
- private lazy var _currentRevision: Int = {
- loadRevision()
- }()
-
- public private(set) var filterSet: Set {
- get { _filterSet }
- set { _filterSet = newValue }
- }
- public private(set) var hashPrefixes: Set {
- get { _hashPrefixes }
- set { _hashPrefixes = newValue }
- }
- public private(set) var currentRevision: Int {
- get { _currentRevision }
- set { _currentRevision = newValue }
- }
-
- private let dataProvider: PhishingDetectionDataProviding
- private let fileStorageManager: FileStorageManager
- private let encoder = JSONEncoder()
- private let revisionFilename = "revision.txt"
- private let hashPrefixFilename = "hashPrefixes.json"
- private let filterSetFilename = "filterSet.json"
-
- public init(dataProvider: PhishingDetectionDataProviding,
- fileStorageManager: FileStorageManager? = nil) {
- self.dataProvider = dataProvider
- if let injectedFileStorageManager = fileStorageManager {
- self.fileStorageManager = injectedFileStorageManager
- } else {
- self.fileStorageManager = PhishingFileStorageManager()
- }
- }
-
- private func writeHashPrefixes() {
- let encoder = JSONEncoder()
- do {
- let hashPrefixesData = try encoder.encode(Array(hashPrefixes))
- fileStorageManager.write(data: hashPrefixesData, to: hashPrefixFilename)
- } catch {
- Logger.phishingDetectionDataStore.error("Error saving hash prefixes data: \(error.localizedDescription)")
- }
- }
-
- private func writeFilterSet() {
- let encoder = JSONEncoder()
- do {
- let filterSetData = try encoder.encode(Array(filterSet))
- fileStorageManager.write(data: filterSetData, to: filterSetFilename)
- } catch {
- Logger.phishingDetectionDataStore.error("Error saving filter set data: \(error.localizedDescription)")
- }
- }
-
- private func writeRevision() {
- let encoder = JSONEncoder()
- do {
- let revisionData = try encoder.encode(currentRevision)
- fileStorageManager.write(data: revisionData, to: revisionFilename)
- } catch {
- Logger.phishingDetectionDataStore.error("Error saving revision data: \(error.localizedDescription)")
- }
- }
-
- private func loadHashPrefix() -> Set {
- guard let data = fileStorageManager.read(from: hashPrefixFilename) else {
- return dataProvider.loadEmbeddedHashPrefixes()
- }
- let decoder = JSONDecoder()
- do {
- if loadRevisionFromDisk() < dataProvider.embeddedRevision {
- return dataProvider.loadEmbeddedHashPrefixes()
- }
- let onDiskHashPrefixes = Set(try decoder.decode(Set.self, from: data))
- return onDiskHashPrefixes
- } catch {
- Logger.phishingDetectionDataStore.error("Error decoding \(self.hashPrefixFilename): \(error.localizedDescription)")
- return dataProvider.loadEmbeddedHashPrefixes()
- }
- }
-
- private func loadFilterSet() -> Set {
- guard let data = fileStorageManager.read(from: filterSetFilename) else {
- return dataProvider.loadEmbeddedFilterSet()
- }
- let decoder = JSONDecoder()
- do {
- if loadRevisionFromDisk() < dataProvider.embeddedRevision {
- return dataProvider.loadEmbeddedFilterSet()
- }
- let onDiskFilterSet = Set(try decoder.decode(Set.self, from: data))
- return onDiskFilterSet
- } catch {
- Logger.phishingDetectionDataStore.error("Error decoding \(self.filterSetFilename): \(error.localizedDescription)")
- return dataProvider.loadEmbeddedFilterSet()
- }
- }
-
- private func loadRevisionFromDisk() -> Int {
- guard let data = fileStorageManager.read(from: revisionFilename) else {
- return dataProvider.embeddedRevision
- }
- let decoder = JSONDecoder()
- do {
- return try decoder.decode(Int.self, from: data)
- } catch {
- Logger.phishingDetectionDataStore.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)")
- return dataProvider.embeddedRevision
- }
- }
-
- private func loadRevision() -> Int {
- guard let data = fileStorageManager.read(from: revisionFilename) else {
- return dataProvider.embeddedRevision
- }
- let decoder = JSONDecoder()
- do {
- let loadedRevision = try decoder.decode(Int.self, from: data)
- if loadedRevision < dataProvider.embeddedRevision {
- return dataProvider.embeddedRevision
- }
- return loadedRevision
- } catch {
- Logger.phishingDetectionDataStore.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)")
- return dataProvider.embeddedRevision
- }
- }
-}
-
-extension PhishingDetectionDataStore {
- public func saveFilterSet(set: Set) {
- self.filterSet = set
- writeFilterSet()
- }
-
- public func saveHashPrefixes(set: Set) {
- self.hashPrefixes = set
- writeHashPrefixes()
- }
-
- public func saveRevision(_ revision: Int) {
- self.currentRevision = revision
- writeRevision()
- }
-}
-
-public protocol FileStorageManager {
- func write(data: Data, to filename: String)
- func read(from filename: String) -> Data?
-}
-
-final class PhishingFileStorageManager: FileStorageManager {
- private let dataStoreURL: URL
-
- init() {
- let dataStoreDirectory: URL
- do {
- dataStoreDirectory = try FileManager.default.url(for: .applicationSupportDirectory, in: .userDomainMask, appropriateFor: nil, create: true)
- } catch {
- Logger.phishingDetectionDataStore.error("Error accessing application support directory: \(error.localizedDescription)")
- dataStoreDirectory = FileManager.default.temporaryDirectory
- }
- dataStoreURL = dataStoreDirectory.appendingPathComponent(Bundle.main.bundleIdentifier!, isDirectory: true)
- createDirectoryIfNeeded()
- }
-
- private func createDirectoryIfNeeded() {
- do {
- try FileManager.default.createDirectory(at: dataStoreURL, withIntermediateDirectories: true, attributes: nil)
- } catch {
- Logger.phishingDetectionDataStore.error("Failed to create directory: \(error.localizedDescription)")
- }
- }
-
- func write(data: Data, to filename: String) {
- let fileURL = dataStoreURL.appendingPathComponent(filename)
- do {
- try data.write(to: fileURL)
- } catch {
- Logger.phishingDetectionDataStore.error("Error writing to directory: \(error.localizedDescription)")
- }
- }
-
- func read(from filename: String) -> Data? {
- let fileURL = dataStoreURL.appendingPathComponent(filename)
- do {
- return try Data(contentsOf: fileURL)
- } catch {
- Logger.phishingDetectionDataStore.error("Error accessing application support directory: \(error)")
- return nil
- }
- }
-}
diff --git a/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift b/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift
deleted file mode 100644
index b811082e3..000000000
--- a/Sources/PhishingDetection/PhishingDetectionUpdateManager.swift
+++ /dev/null
@@ -1,83 +0,0 @@
-//
-// PhishingDetectionUpdateManager.swift
-//
-// Copyright © 2023 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import Common
-import os
-
-public protocol PhishingDetectionUpdateManaging {
- func updateFilterSet() async
- func updateHashPrefixes() async
-}
-
-public class PhishingDetectionUpdateManager: PhishingDetectionUpdateManaging {
- var apiClient: PhishingDetectionClientProtocol
- var dataStore: PhishingDetectionDataSaving
-
- public init(client: PhishingDetectionClientProtocol, dataStore: PhishingDetectionDataSaving) {
- self.apiClient = client
- self.dataStore = dataStore
- }
-
- private func updateSet(
- currentSet: Set,
- insert: [T],
- delete: [T],
- replace: Bool,
- saveSet: (Set) -> Void
- ) {
- var newSet = currentSet
-
- if replace {
- newSet = Set(insert)
- } else {
- newSet.formUnion(insert)
- newSet.subtract(delete)
- }
-
- saveSet(newSet)
- }
-
- public func updateFilterSet() async {
- let response = await apiClient.getFilterSet(revision: dataStore.currentRevision)
- updateSet(
- currentSet: dataStore.filterSet,
- insert: response.insert,
- delete: response.delete,
- replace: response.replace
- ) { newSet in
- self.dataStore.saveFilterSet(set: newSet)
- }
- dataStore.saveRevision(response.revision)
- Logger.phishingDetectionUpdateManager.debug("filterSet updated to revision \(self.dataStore.currentRevision)")
- }
-
- public func updateHashPrefixes() async {
- let response = await apiClient.getHashPrefixes(revision: dataStore.currentRevision)
- updateSet(
- currentSet: dataStore.hashPrefixes,
- insert: response.insert,
- delete: response.delete,
- replace: response.replace
- ) { newSet in
- self.dataStore.saveHashPrefixes(set: newSet)
- }
- dataStore.saveRevision(response.revision)
- Logger.phishingDetectionUpdateManager.debug("hashPrefixes updated to revision \(self.dataStore.currentRevision)")
- }
-}
diff --git a/Sources/PhishingDetection/PhishingDetector.swift b/Sources/PhishingDetection/PhishingDetector.swift
deleted file mode 100644
index 3ccbe9b7e..000000000
--- a/Sources/PhishingDetection/PhishingDetector.swift
+++ /dev/null
@@ -1,130 +0,0 @@
-//
-// PhishingDetector.swift
-//
-// Copyright © 2023 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import CryptoKit
-import Common
-import WebKit
-
-public enum PhishingDetectionError: CustomNSError {
- case detected
-
- public static let errorDomain: String = "PhishingDetectionError"
-
- public var errorCode: Int {
- switch self {
- case .detected:
- return 1331
- }
- }
-
- public var errorUserInfo: [String: Any] {
- switch self {
- case .detected:
- return [NSLocalizedDescriptionKey: "Phishing detected"]
- }
- }
-
- public var rawValue: Int {
- return self.errorCode
- }
-}
-
-public protocol PhishingDetecting {
- func isMalicious(url: URL) async -> Bool
-}
-
-public class PhishingDetector: PhishingDetecting {
- let hashPrefixStoreLength: Int = 8
- let hashPrefixParamLength: Int = 4
- let apiClient: PhishingDetectionClientProtocol
- let dataStore: PhishingDetectionDataSaving
- let eventMapping: EventMapping
-
- public init(apiClient: PhishingDetectionClientProtocol, dataStore: PhishingDetectionDataSaving, eventMapping: EventMapping) {
- self.apiClient = apiClient
- self.dataStore = dataStore
- self.eventMapping = eventMapping
- }
-
- private func getMatches(hashPrefix: String) async -> Set {
- return Set(await apiClient.getMatches(hashPrefix: hashPrefix))
- }
-
- private func inFilterSet(hash: String) -> Set {
- return Set(dataStore.filterSet.filter { $0.hashValue == hash })
- }
-
- private func matchesUrl(hash: String, regexPattern: String, url: URL, hostnameHash: String) -> Bool {
- if hash == hostnameHash,
- let regex = try? NSRegularExpression(pattern: regexPattern, options: []) {
- let urlString = url.absoluteString
- let range = NSRange(location: 0, length: urlString.utf16.count)
- return regex.firstMatch(in: urlString, options: [], range: range) != nil
- }
- return false
- }
-
- private func generateHashPrefix(for canonicalHost: String, length: Int) -> String {
- let hostnameHash = SHA256.hash(data: Data(canonicalHost.utf8)).map { String(format: "%02hhx", $0) }.joined()
- return String(hostnameHash.prefix(length))
- }
-
- private func fetchMatches(hashPrefix: String) async -> [Match] {
- return await apiClient.getMatches(hashPrefix: hashPrefix)
- }
-
- private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool {
- let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max)
- let filterHit = inFilterSet(hash: hostnameHash)
- for filter in filterHit where matchesUrl(hash: filter.hashValue, regexPattern: filter.regex, url: canonicalUrl, hostnameHash: hostnameHash) {
- eventMapping.fire(PhishingDetectionEvents.errorPageShown(clientSideHit: true))
- return true
- }
- return false
- }
-
- private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Bool {
- let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: hashPrefixParamLength)
- let matches = await fetchMatches(hashPrefix: hashPrefixParam)
- let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max)
- for match in matches where matchesUrl(hash: match.hash, regexPattern: match.regex, url: canonicalUrl, hostnameHash: hostnameHash) {
- eventMapping.fire(PhishingDetectionEvents.errorPageShown(clientSideHit: false))
- return true
- }
- return false
- }
-
- public func isMalicious(url: URL) async -> Bool {
- guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return false }
-
- let hashPrefix = generateHashPrefix(for: canonicalHost, length: hashPrefixStoreLength)
- if dataStore.hashPrefixes.contains(hashPrefix) {
- // Check local filterSet first
- if checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) {
- return true
- }
- // If nothing found, hit the API to get matches
- if await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) {
- return true
- }
- }
-
- return false
- }
-}
diff --git a/Sources/PrivacyDashboard/PrivacyDashboardController.swift b/Sources/PrivacyDashboard/PrivacyDashboardController.swift
index 8093a02d2..988c4cd20 100644
--- a/Sources/PrivacyDashboard/PrivacyDashboardController.swift
+++ b/Sources/PrivacyDashboard/PrivacyDashboardController.swift
@@ -16,12 +16,13 @@
// limitations under the License.
//
-import Foundation
-import WebKit
-import Combine
-import PrivacyDashboardResources
import BrowserServicesKit
+import Combine
import Common
+import Foundation
+import MaliciousSiteProtection
+import PrivacyDashboardResources
+import WebKit
public enum PrivacyDashboardOpenSettingsTarget: String {
@@ -205,7 +206,7 @@ extension PrivacyDashboardController: WKNavigationDelegate {
subscribeToServerTrust()
subscribeToConsentManaged()
subscribeToAllowedPermissions()
- subscribeToIsPhishing()
+ subscribeToMaliciousSiteThreatKind()
}
private func subscribeToTheme() {
@@ -259,12 +260,12 @@ extension PrivacyDashboardController: WKNavigationDelegate {
.store(in: &cancellables)
}
- private func subscribeToIsPhishing() {
- privacyInfo?.$isPhishing
+ private func subscribeToMaliciousSiteThreatKind() {
+ privacyInfo?.$malicousSiteThreatKind
.receive(on: DispatchQueue.main )
- .sink(receiveValue: { [weak self] isPhishing in
- guard let self = self, let webView = self.webView else { return }
- script.setIsPhishing(isPhishing, webView: webView)
+ .sink(receiveValue: { [weak self] detectedThreatKind in
+ guard let self, let webView else { return }
+ script.setMaliciousSiteDetectedThreatKind(detectedThreatKind, webView: webView)
})
.store(in: &cancellables)
}
diff --git a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift
index 23fc5e738..801cdd81c 100644
--- a/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift
+++ b/Sources/PrivacyDashboard/PrivacyDashboardUserScript.swift
@@ -16,12 +16,13 @@
// limitations under the License.
//
+import BrowserServicesKit
+import Common
import Foundation
-import WebKit
+import MaliciousSiteProtection
import TrackerRadarKit
import UserScript
-import Common
-import BrowserServicesKit
+import WebKit
@MainActor
protocol PrivacyDashboardUserScriptDelegate: AnyObject {
@@ -425,8 +426,8 @@ final class PrivacyDashboardUserScript: NSObject, StaticUserScript {
evaluate(js: "window.onChangeCertificateData(\(certificateDataJson))", in: webView)
}
- func setIsPhishing(_ isPhishing: Bool, webView: WKWebView) {
- let phishingStatus = ["phishingStatus": isPhishing]
+ func setMaliciousSiteDetectedThreatKind(_ detectedThreatKind: MaliciousSiteProtection.ThreatKind?, webView: WKWebView) {
+ let phishingStatus = ["phishingStatus": detectedThreatKind == .phishing]
guard let phishingStatusJson = try? JSONEncoder().encode(phishingStatus).utf8String() else {
assertionFailure("Can't encode phishingStatus into JSON")
return
diff --git a/Sources/PrivacyDashboard/PrivacyInfo.swift b/Sources/PrivacyDashboard/PrivacyInfo.swift
index b9db906fc..3eaabc185 100644
--- a/Sources/PrivacyDashboard/PrivacyInfo.swift
+++ b/Sources/PrivacyDashboard/PrivacyInfo.swift
@@ -16,9 +16,10 @@
// limitations under the License.
//
+import Common
import Foundation
+import MaliciousSiteProtection
import TrackerRadarKit
-import Common
public protocol SecurityTrust { }
extension SecTrust: SecurityTrust {}
@@ -33,15 +34,15 @@ public final class PrivacyInfo {
@Published public var serverTrust: SecurityTrust?
@Published public var connectionUpgradedTo: URL?
@Published public var cookieConsentManaged: CookieConsentInfo?
- @Published public var isPhishing: Bool
+ @Published public var malicousSiteThreatKind: MaliciousSiteProtection.ThreatKind?
@Published public var isSpecialErrorPageVisible: Bool = false
@Published public var shouldCheckServerTrust: Bool
- public init(url: URL, parentEntity: Entity?, protectionStatus: ProtectionStatus, isPhishing: Bool = false, shouldCheckServerTrust: Bool = false) {
+ public init(url: URL, parentEntity: Entity?, protectionStatus: ProtectionStatus, malicousSiteThreatKind: MaliciousSiteProtection.ThreatKind? = .none, shouldCheckServerTrust: Bool = false) {
self.url = url
self.parentEntity = parentEntity
self.protectionStatus = protectionStatus
- self.isPhishing = isPhishing
+ self.malicousSiteThreatKind = malicousSiteThreatKind
self.shouldCheckServerTrust = shouldCheckServerTrust
trackerInfo = TrackerInfo()
diff --git a/Sources/PrivacyStats/Logger+PrivacyStats.swift b/Sources/PrivacyStats/Logger+PrivacyStats.swift
new file mode 100644
index 000000000..e8649a6af
--- /dev/null
+++ b/Sources/PrivacyStats/Logger+PrivacyStats.swift
@@ -0,0 +1,24 @@
+//
+// Logger+PrivacyStats.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import os.log
+
+public extension Logger {
+ static var privacyStats = { Logger(subsystem: "Privacy Stats", category: "") }()
+}
diff --git a/Sources/PrivacyStats/PrivacyStats.swift b/Sources/PrivacyStats/PrivacyStats.swift
new file mode 100644
index 000000000..c298f60fc
--- /dev/null
+++ b/Sources/PrivacyStats/PrivacyStats.swift
@@ -0,0 +1,249 @@
+//
+// PrivacyStats.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Combine
+import Common
+import CoreData
+import Foundation
+import os.log
+import Persistence
+import TrackerRadarKit
+
+/**
+ * Errors that may be reported by `PrivacyStats`.
+ */
+public enum PrivacyStatsError: CustomNSError {
+ case failedToFetchPrivacyStatsSummary(Error)
+ case failedToStorePrivacyStats(Error)
+ case failedToLoadCurrentPrivacyStats(Error)
+
+ public static let errorDomain: String = "PrivacyStatsError"
+
+ public var errorCode: Int {
+ switch self {
+ case .failedToFetchPrivacyStatsSummary:
+ return 1
+ case .failedToStorePrivacyStats:
+ return 2
+ case .failedToLoadCurrentPrivacyStats:
+ return 3
+ }
+ }
+
+ public var underlyingError: Error {
+ switch self {
+ case .failedToFetchPrivacyStatsSummary(let error),
+ .failedToStorePrivacyStats(let error),
+ .failedToLoadCurrentPrivacyStats(let error):
+ return error
+ }
+ }
+}
+
+/**
+ * This protocol describes database provider consumed by `PrivacyStats`.
+ */
+public protocol PrivacyStatsDatabaseProviding {
+ func initializeDatabase() -> CoreDataDatabase
+}
+
+/**
+ * This protocol describes `PrivacyStats` interface.
+ */
+public protocol PrivacyStatsCollecting {
+
+ /**
+ * Record a tracker for a given `companyName`.
+ *
+ * `PrivacyStats` implementation calls the `CurrentPack` actor under the hood,
+ * and as such it can safely be called on multiple threads concurrently.
+ */
+ func recordBlockedTracker(_ name: String) async
+
+ /**
+ * Publisher emitting values whenever updated privacy stats were persisted to disk.
+ */
+ var statsUpdatePublisher: AnyPublisher { get }
+
+ /**
+ * This function fetches privacy stats in a dictionary format
+ * with keys being company names and values being total number
+ * of tracking attempts blocked in past 7 days.
+ */
+ func fetchPrivacyStats() async -> [String: Int64]
+
+ /**
+ * This function clears all blocked tracker stats from the database.
+ */
+ func clearPrivacyStats() async
+
+ /**
+ * This function saves all pending changes to the persistent storage.
+ *
+ * It should only be used in response to app termination because otherwise
+ * the `PrivacyStats` object schedules persisting internally.
+ */
+ func handleAppTermination() async
+}
+
+public final class PrivacyStats: PrivacyStatsCollecting {
+
+ public static let bundle = Bundle.module
+
+ public let statsUpdatePublisher: AnyPublisher
+
+ private let db: CoreDataDatabase
+ private let context: NSManagedObjectContext
+ private let currentPack: CurrentPack
+ private let statsUpdateSubject = PassthroughSubject()
+ private var cancellables: Set = []
+
+ private let errorEvents: EventMapping?
+
+ public init(databaseProvider: PrivacyStatsDatabaseProviding, errorEvents: EventMapping? = nil) {
+ self.db = databaseProvider.initializeDatabase()
+ self.context = db.makeContext(concurrencyType: .privateQueueConcurrencyType, name: "PrivacyStats")
+ self.errorEvents = errorEvents
+ self.currentPack = .init(pack: Self.initializeCurrentPack(in: context, errorEvents: errorEvents))
+ statsUpdatePublisher = statsUpdateSubject.eraseToAnyPublisher()
+
+ currentPack.commitChangesPublisher
+ .sink { [weak self] pack in
+ Task {
+ await self?.commitChanges(pack)
+ }
+ }
+ .store(in: &cancellables)
+ }
+
+ public func recordBlockedTracker(_ companyName: String) async {
+ await currentPack.recordBlockedTracker(companyName)
+ }
+
+ public func fetchPrivacyStats() async -> [String: Int64] {
+ return await withCheckedContinuation { continuation in
+ context.perform { [weak self] in
+ guard let self else {
+ continuation.resume(returning: [:])
+ return
+ }
+ do {
+ let stats = try PrivacyStatsUtils.load7DayStats(in: context)
+ continuation.resume(returning: stats)
+ } catch {
+ errorEvents?.fire(.failedToFetchPrivacyStatsSummary(error))
+ continuation.resume(returning: [:])
+ }
+ }
+ }
+ }
+
+ public func clearPrivacyStats() async {
+ await withCheckedContinuation { continuation in
+ context.perform { [weak self] in
+ guard let self else {
+ continuation.resume()
+ return
+ }
+ do {
+ try PrivacyStatsUtils.deleteAllStats(in: context)
+ Logger.privacyStats.debug("Deleted outdated entries")
+ } catch {
+ Logger.privacyStats.error("Save error: \(error)")
+ errorEvents?.fire(.failedToFetchPrivacyStatsSummary(error))
+ }
+ continuation.resume()
+ }
+ }
+ await currentPack.resetPack()
+ statsUpdateSubject.send()
+ }
+
+ public func handleAppTermination() async {
+ await commitChanges(currentPack.pack)
+ }
+
+ // MARK: - Private
+
+ private func commitChanges(_ pack: PrivacyStatsPack) async {
+ await withCheckedContinuation { continuation in
+ context.perform { [weak self] in
+ guard let self else {
+ continuation.resume()
+ return
+ }
+
+ // Check if the pack we're currently storing is from a previous day.
+ let isCurrentDayPack = pack.timestamp == Date.currentPrivacyStatsPackTimestamp
+
+ do {
+ let statsObjects = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: Set(pack.trackers.keys), in: context)
+ statsObjects.forEach { stats in
+ if let count = pack.trackers[stats.companyName] {
+ stats.count = count
+ }
+ }
+
+ guard context.hasChanges else {
+ continuation.resume()
+ return
+ }
+
+ try context.save()
+ Logger.privacyStats.debug("Saved stats \(pack.timestamp) \(pack.trackers)")
+
+ if isCurrentDayPack {
+ // Only emit update event when saving current-day pack. For previous-day pack,
+ // a follow-up commit event will come and we'll emit the update then.
+ statsUpdateSubject.send()
+ } else {
+ // When storing a pack from a previous day, we may have outdated packs, so delete them as needed.
+ try PrivacyStatsUtils.deleteOutdatedPacks(in: context)
+ }
+ } catch {
+ Logger.privacyStats.error("Save error: \(error)")
+ errorEvents?.fire(.failedToStorePrivacyStats(error))
+ }
+ continuation.resume()
+ }
+ }
+ }
+
+ /**
+ * This function is only called in the initializer. It performs a blocking call to the database
+ * to spare us the hassle of declaring the initializer async or spawning tasks from within the
+ * initializer without being able to await them, thus making testing trickier.
+ */
+ private static func initializeCurrentPack(in context: NSManagedObjectContext, errorEvents: EventMapping?) -> PrivacyStatsPack {
+ var pack: PrivacyStatsPack?
+ context.performAndWait {
+ let timestamp = Date.currentPrivacyStatsPackTimestamp
+ do {
+ let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context)
+ Logger.privacyStats.debug("Loaded stats \(timestamp) \(currentDayStats)")
+ pack = PrivacyStatsPack(timestamp: timestamp, trackers: currentDayStats)
+
+ try PrivacyStatsUtils.deleteOutdatedPacks(in: context)
+ } catch {
+ Logger.privacyStats.error("Failed to load current stats: \(error)")
+ errorEvents?.fire(.failedToLoadCurrentPrivacyStats(error))
+ }
+ }
+ return pack ?? PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp)
+ }
+}
diff --git a/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion
new file mode 100644
index 000000000..1a19d1654
--- /dev/null
+++ b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/.xccurrentversion
@@ -0,0 +1,8 @@
+
+
+
+
+ _XCCurrentVersionName
+ PrivacyStats.xcdatamodel
+
+
diff --git a/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents
new file mode 100644
index 000000000..39798857a
--- /dev/null
+++ b/Sources/PrivacyStats/PrivacyStats.xcdatamodeld/PrivacyStats.xcdatamodel/contents
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Sources/PrivacyStats/internal/CurrentPack.swift b/Sources/PrivacyStats/internal/CurrentPack.swift
new file mode 100644
index 000000000..fffc6b090
--- /dev/null
+++ b/Sources/PrivacyStats/internal/CurrentPack.swift
@@ -0,0 +1,116 @@
+//
+// CurrentPack.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Combine
+import Foundation
+import os.log
+
+/**
+ * This actor provides thread-safe access to an instance of `PrivacyStatsPack`.
+ *
+ * It's used by `PrivacyStats` class to record blocked trackers that can possibly
+ * come from multiple open tabs (web views) at the same time.
+ */
+actor CurrentPack {
+ /**
+ * Current stats pack.
+ */
+ private(set) var pack: PrivacyStatsPack
+
+ /**
+ * Publisher that fires events whenever tracker stats are ready to be persisted to disk.
+ *
+ * This happens after recording new blocked tracker, when no new tracker has been recorded for 1s.
+ */
+ nonisolated private(set) lazy var commitChangesPublisher: AnyPublisher = commitChangesSubject.eraseToAnyPublisher()
+
+ nonisolated private let commitChangesSubject = PassthroughSubject()
+ private var commitTask: Task?
+ private var commitDebounce: UInt64
+
+ /// The `commitDebounce` parameter should only be modified in unit tests.
+ init(pack: PrivacyStatsPack, commitDebounce: UInt64 = 1_000_000_000) {
+ self.pack = pack
+ self.commitDebounce = commitDebounce
+ }
+
+ deinit {
+ commitTask?.cancel()
+ }
+
+ /**
+ * This function is used when clearing app data, to clear any stats cached in memory.
+ *
+ * It sets a new empty pack with the current timestamp.
+ */
+ func resetPack() {
+ resetStats(andSet: Date.currentPrivacyStatsPackTimestamp)
+ }
+
+ /**
+ * This function increments trackers count for a given company name.
+ *
+ * Updates are kept in memory and scheduled for saving to persistent storage with 1s debounce.
+ * This function also detects when the current pack becomes outdated (which happens
+ * when current timestamp's day becomes greater than pack's timestamp's day), in which
+ * case current pack is scheduled for persisting on disk and a new empty pack is
+ * created for the new timestamp.
+ */
+ func recordBlockedTracker(_ companyName: String) {
+
+ let currentTimestamp = Date.currentPrivacyStatsPackTimestamp
+ if currentTimestamp != pack.timestamp {
+ Logger.privacyStats.debug("New timestamp detected, storing trackers state and creating new pack")
+ notifyChanges(for: pack, immediately: true)
+ resetStats(andSet: currentTimestamp)
+ }
+
+ let count = pack.trackers[companyName] ?? 0
+ pack.trackers[companyName] = count + 1
+
+ notifyChanges(for: pack, immediately: false)
+ }
+
+ private func notifyChanges(for pack: PrivacyStatsPack, immediately shouldPublishImmediately: Bool) {
+ commitTask?.cancel()
+
+ if shouldPublishImmediately {
+
+ commitChangesSubject.send(pack)
+
+ } else {
+
+ commitTask = Task {
+ do {
+ // Note that this doesn't always sleep for the full debounce time, but the sleep is interrupted
+ // as soon as the task gets cancelled.
+ try await Task.sleep(nanoseconds: commitDebounce)
+
+ Logger.privacyStats.debug("Storing trackers state")
+ commitChangesSubject.send(pack)
+ } catch {
+ // Commit task got cancelled
+ }
+ }
+ }
+ }
+
+ private func resetStats(andSet newTimestamp: Date) {
+ pack = PrivacyStatsPack(timestamp: newTimestamp, trackers: [:])
+ }
+}
diff --git a/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift b/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift
new file mode 100644
index 000000000..728a5943c
--- /dev/null
+++ b/Sources/PrivacyStats/internal/DailyBlockedTrackersEntity.swift
@@ -0,0 +1,55 @@
+//
+// DailyBlockedTrackersEntity.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import CoreData
+
+@objc(DailyBlockedTrackersEntity)
+final class DailyBlockedTrackersEntity: NSManagedObject {
+ enum Const {
+ static let entityName = "DailyBlockedTrackersEntity"
+ }
+
+ @nonobjc class func fetchRequest() -> NSFetchRequest {
+ NSFetchRequest(entityName: Const.entityName)
+ }
+
+ class func entity(in context: NSManagedObjectContext) -> NSEntityDescription {
+ NSEntityDescription.entity(forEntityName: Const.entityName, in: context)!
+ }
+
+ @NSManaged var companyName: String
+ @NSManaged var count: Int64
+ @NSManaged var timestamp: Date
+
+ private override init(entity: NSEntityDescription, insertInto context: NSManagedObjectContext?) {
+ super.init(entity: entity, insertInto: context)
+ }
+
+ private convenience init(context moc: NSManagedObjectContext) {
+ self.init(entity: DailyBlockedTrackersEntity.entity(in: moc), insertInto: moc)
+ }
+
+ static func make(timestamp: Date = Date(), companyName: String, count: Int64 = 0, context: NSManagedObjectContext) -> DailyBlockedTrackersEntity {
+ let object = DailyBlockedTrackersEntity(context: context)
+ object.timestamp = timestamp.privacyStatsPackTimestamp
+ object.companyName = companyName
+ object.count = count
+ return object
+ }
+}
diff --git a/Sources/PrivacyStats/internal/Date+PrivacyStats.swift b/Sources/PrivacyStats/internal/Date+PrivacyStats.swift
new file mode 100644
index 000000000..f56072d75
--- /dev/null
+++ b/Sources/PrivacyStats/internal/Date+PrivacyStats.swift
@@ -0,0 +1,51 @@
+//
+// Date+PrivacyStats.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import Foundation
+
+extension Date {
+
+ /**
+ * Returns privacy stats pack timestamp for the current date.
+ *
+ * See `privacyStatsPackTimestamp`.
+ */
+ static var currentPrivacyStatsPackTimestamp: Date {
+ Date().privacyStatsPackTimestamp
+ }
+
+ /**
+ * Returns a valid timestamp for `DailyBlockedTrackersEntity` instance matching the sender.
+ *
+ * Blocked trackers are packed by day so the timestap of the pack must be the exact start of a day.
+ */
+ var privacyStatsPackTimestamp: Date {
+ startOfDay
+ }
+
+ /**
+ * Returns the oldest valid timestamp for `DailyBlockedTrackersEntity` instance matching the sender.
+ *
+ * Privacy Stats only keeps track of 7 days worth of tracking history, so the oldest timestamp is
+ * beginning of the day 6 days ago.
+ */
+ var privacyStatsOldestPackTimestamp: Date {
+ privacyStatsPackTimestamp.daysAgo(6)
+ }
+}
diff --git a/Sources/PrivacyStats/internal/PrivacyStatsPack.swift b/Sources/PrivacyStats/internal/PrivacyStatsPack.swift
new file mode 100644
index 000000000..3c4c8a04b
--- /dev/null
+++ b/Sources/PrivacyStats/internal/PrivacyStatsPack.swift
@@ -0,0 +1,32 @@
+//
+// PrivacyStatsPack.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+
+/**
+ * This struct keeps track of the summary of blocked trackers for a single unit of time (1 day).
+ */
+struct PrivacyStatsPack: Equatable {
+ let timestamp: Date
+ var trackers: [String: Int64]
+
+ init(timestamp: Date, trackers: [String: Int64] = [:]) {
+ self.timestamp = timestamp
+ self.trackers = trackers
+ }
+}
diff --git a/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift b/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift
new file mode 100644
index 000000000..33b93d869
--- /dev/null
+++ b/Sources/PrivacyStats/internal/PrivacyStatsUtils.swift
@@ -0,0 +1,123 @@
+//
+// PrivacyStatsUtils.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Common
+import CoreData
+import Foundation
+import Persistence
+
+final class PrivacyStatsUtils {
+
+ /**
+ * Returns objects corresponding to current stats for companies specified by `companyNames`.
+ *
+ * If an object doesn't exist (no trackers for a given company were reported on a given day)
+ * then a new object for that company is inserted into the context and returned.
+ * If a user opens the app for the first time on a given day, the database will not contain
+ * any records for that day and this function will only insert new objects.
+ *
+ * > Note: `current stats` refer to stats objects that are active on a given day, i.e. their
+ * timestamp's day matches current day.
+ */
+ static func fetchOrInsertCurrentStats(for companyNames: Set, in context: NSManagedObjectContext) throws -> [DailyBlockedTrackersEntity] {
+ let timestamp = Date.currentPrivacyStatsPackTimestamp
+
+ let request = DailyBlockedTrackersEntity.fetchRequest()
+ request.predicate = NSPredicate(format: "%K == %@ AND %K in %@",
+ #keyPath(DailyBlockedTrackersEntity.timestamp), timestamp as NSDate,
+ #keyPath(DailyBlockedTrackersEntity.companyName), companyNames)
+ request.returnsObjectsAsFaults = false
+
+ var statsObjects = try context.fetch(request)
+ let missingCompanyNames = companyNames.subtracting(statsObjects.map(\.companyName))
+
+ for companyName in missingCompanyNames {
+ statsObjects.append(DailyBlockedTrackersEntity.make(timestamp: timestamp, companyName: companyName, context: context))
+ }
+ return statsObjects
+ }
+
+ /**
+ * Returns a dictionary representation of blocked trackers counts grouped by company name for the current day.
+ */
+ static func loadCurrentDayStats(in context: NSManagedObjectContext) throws -> [String: Int64] {
+ let startDate = Date.currentPrivacyStatsPackTimestamp
+ return try loadBlockedTrackersStats(since: startDate, in: context)
+ }
+
+ /**
+ * Returns a dictionary representation of blocked trackers counts grouped by company name for past 7 days.
+ */
+ static func load7DayStats(in context: NSManagedObjectContext) throws -> [String: Int64] {
+ let startDate = Date().privacyStatsOldestPackTimestamp
+ return try loadBlockedTrackersStats(since: startDate, in: context)
+ }
+
+ private static func loadBlockedTrackersStats(since startDate: Date, in context: NSManagedObjectContext) throws -> [String: Int64] {
+ let request = NSFetchRequest(entityName: DailyBlockedTrackersEntity.Const.entityName)
+ request.predicate = NSPredicate(format: "%K >= %@", #keyPath(DailyBlockedTrackersEntity.timestamp), startDate as NSDate)
+
+ let companyNameKey = #keyPath(DailyBlockedTrackersEntity.companyName)
+
+ // Expression description for the sum of count
+ let countExpression = NSExpression(forKeyPath: #keyPath(DailyBlockedTrackersEntity.count))
+ let sumExpression = NSExpression(forFunction: "sum:", arguments: [countExpression])
+
+ let sumExpressionDescription = NSExpressionDescription()
+ sumExpressionDescription.name = "totalCount"
+ sumExpressionDescription.expression = sumExpression
+ sumExpressionDescription.expressionResultType = .integer64AttributeType
+
+ request.propertiesToGroupBy = [companyNameKey]
+ request.propertiesToFetch = [companyNameKey, sumExpressionDescription]
+ request.resultType = .dictionaryResultType
+
+ let results = (try context.fetch(request) as? [[String: Any]]) ?? []
+
+ let groupedResults = results.reduce(into: [String: Int64]()) { partialResult, result in
+ if let companyName = result[companyNameKey] as? String, let totalCount = result["totalCount"] as? Int64, totalCount > 0 {
+ partialResult[companyName] = totalCount
+ }
+ }
+
+ return groupedResults
+ }
+
+ /**
+ * Deletes stats older than 7 days for all companies.
+ */
+ static func deleteOutdatedPacks(in context: NSManagedObjectContext) throws {
+ let oldestValidTimestamp = Date().privacyStatsOldestPackTimestamp
+
+ let fetchRequest = NSFetchRequest(entityName: DailyBlockedTrackersEntity.Const.entityName)
+ fetchRequest.predicate = NSPredicate(format: "%K < %@", #keyPath(DailyBlockedTrackersEntity.timestamp), oldestValidTimestamp as NSDate)
+ let deleteRequest = NSBatchDeleteRequest(fetchRequest: fetchRequest)
+
+ try context.execute(deleteRequest)
+ context.reset()
+ }
+
+ /**
+ * Deletes all stats entries in the database.
+ */
+ static func deleteAllStats(in context: NSManagedObjectContext) throws {
+ let deleteRequest = NSBatchDeleteRequest(fetchRequest: DailyBlockedTrackersEntity.fetchRequest())
+ try context.execute(deleteRequest)
+ context.reset()
+ }
+}
diff --git a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift
index dc3f3e48b..37e53e352 100644
--- a/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift
+++ b/Sources/RemoteMessaging/Mappers/DefaultRemoteMessagingSurveyURLBuilder.swift
@@ -31,11 +31,16 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio
private let statisticsStore: StatisticsStore
private let vpnActivationDateStore: VPNActivationDateProviding
private let subscription: Subscription?
+ private let localeIdentifier: String
- public init(statisticsStore: StatisticsStore, vpnActivationDateStore: VPNActivationDateProviding, subscription: Subscription?) {
+ public init(statisticsStore: StatisticsStore,
+ vpnActivationDateStore: VPNActivationDateProviding,
+ subscription: Subscription?,
+ localeIdentifier: String = Locale.current.identifier) {
self.statisticsStore = statisticsStore
self.vpnActivationDateStore = vpnActivationDateStore
self.subscription = subscription
+ self.localeIdentifier = localeIdentifier
}
// swiftlint:disable:next cyclomatic_complexity
@@ -72,6 +77,9 @@ public struct DefaultRemoteMessagingSurveyURLBuilder: RemoteMessagingSurveyActio
let daysSinceInstall = Calendar.current.numberOfDaysBetween(installDate, and: Date()) {
queryItems.append(URLQueryItem(name: parameter.rawValue, value: String(describing: daysSinceInstall)))
}
+ case .locale:
+ let formattedLocale = LocaleMatchingAttribute.localeIdentifierAsJsonFormat(localeIdentifier)
+ queryItems.append(URLQueryItem(name: parameter.rawValue, value: formattedLocale))
case .privacyProStatus:
if let privacyProStatusSurveyParameter = subscription?.privacyProStatusSurveyParameter {
queryItems.append(URLQueryItem(name: parameter.rawValue, value: privacyProStatusSurveyParameter))
diff --git a/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift b/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift
index 668b03e1c..a927e1c47 100644
--- a/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift
+++ b/Sources/RemoteMessaging/Mappers/RemoteMessagingSurveyActionMapping.swift
@@ -24,6 +24,7 @@ public enum RemoteMessagingSurveyActionParameter: String, CaseIterable {
case atbVariant = "var"
case daysInstalled = "delta"
case hardwareModel = "mo"
+ case locale = "locale"
case osVersion = "osv"
case privacyProStatus = "ppro_status"
case privacyProPlatform = "ppro_platform"
diff --git a/Sources/SpecialErrorPages/SSLErrorType.swift b/Sources/SpecialErrorPages/SSLErrorType.swift
index cf483b8e2..58137a293 100644
--- a/Sources/SpecialErrorPages/SSLErrorType.swift
+++ b/Sources/SpecialErrorPages/SSLErrorType.swift
@@ -17,28 +17,27 @@
//
import Foundation
+import WebKit
-public enum SSLErrorType: String {
+public let SSLErrorCodeKey = "_kCFStreamErrorCodeKey"
+
+public enum SSLErrorType: String, Encodable {
case expired
- case wrongHost
case selfSigned
+ case wrongHost
case invalid
- public static func forErrorCode(_ errorCode: Int) -> Self {
- switch Int32(errorCode) {
- case errSSLCertExpired:
- return .expired
- case errSSLHostNameMismatch:
- return .wrongHost
- case errSSLXCertChainInvalid:
- return .selfSigned
- default:
- return .invalid
+ init(errorCode: Int32) {
+ self = switch errorCode {
+ case errSSLCertExpired: .expired
+ case errSSLXCertChainInvalid: .selfSigned
+ case errSSLHostNameMismatch: .wrongHost
+ default: .invalid
}
}
- public var rawParameter: String {
+ public var pixelParameter: String {
switch self {
case .expired: return "expired"
case .wrongHost: return "wrong_host"
@@ -48,3 +47,16 @@ public enum SSLErrorType: String {
}
}
+
+extension WKError {
+ public var sslErrorType: SSLErrorType? {
+ _nsError.sslErrorType
+ }
+}
+extension NSError {
+ public var sslErrorType: SSLErrorType? {
+ guard let errorCode = self.userInfo[SSLErrorCodeKey] as? Int32 else { return nil }
+ let sslErrorType = SSLErrorType(errorCode: errorCode)
+ return sslErrorType
+ }
+}
diff --git a/Sources/SpecialErrorPages/SpecialErrorData.swift b/Sources/SpecialErrorPages/SpecialErrorData.swift
index 7ceb0baef..048077847 100644
--- a/Sources/SpecialErrorPages/SpecialErrorData.swift
+++ b/Sources/SpecialErrorPages/SpecialErrorData.swift
@@ -17,24 +17,61 @@
//
import Foundation
+import MaliciousSiteProtection
public enum SpecialErrorKind: String, Encodable {
case ssl
case phishing
+ // case malware
}
-public struct SpecialErrorData: Encodable, Equatable {
+public enum SpecialErrorData: Encodable, Equatable {
- var kind: SpecialErrorKind
- var errorType: String?
- var domain: String?
- var eTldPlus1: String?
+ enum CodingKeys: CodingKey {
+ case kind
+ case errorType
+ case domain
+ case eTldPlus1
+ case url
+ }
+
+ case ssl(type: SSLErrorType, domain: String, eTldPlus1: String?)
+ case maliciousSite(kind: MaliciousSiteProtection.ThreatKind, url: URL)
+
+ public func encode(to encoder: any Encoder) throws {
+ var container = encoder.container(keyedBy: CodingKeys.self)
+
+ switch self {
+ case .ssl(type: let type, domain: let domain, eTldPlus1: let eTldPlus1):
+ try container.encode(SpecialErrorKind.ssl, forKey: .kind)
+ try container.encode(type, forKey: .errorType)
+ try container.encode(domain, forKey: .domain)
- public init(kind: SpecialErrorKind, errorType: String? = nil, domain: String? = nil, eTldPlus1: String? = nil) {
- self.kind = kind
- self.errorType = errorType
- self.domain = domain
- self.eTldPlus1 = eTldPlus1
+ switch type {
+ case .expired, .selfSigned, .invalid: break
+ case .wrongHost:
+ guard let eTldPlus1 else {
+ assertionFailure("expected eTldPlus1 != nil when kind is .wrongHost")
+ break
+ }
+ try container.encode(eTldPlus1, forKey: .eTldPlus1)
+ }
+
+ case .maliciousSite(kind: let kind, url: let url):
+ // https://app.asana.com/0/1206594217596623/1208824527069247/f
+ var container = encoder.container(keyedBy: CodingKeys.self)
+ try container.encode(kind.errorPageKind, forKey: .kind)
+ try container.encode(url, forKey: .url)
+ }
}
}
+
+public extension MaliciousSiteProtection.ThreatKind {
+ var errorPageKind: SpecialErrorKind {
+ switch self {
+ // case .malware: .malware
+ case .phishing: .phishing
+ }
+ }
+}
diff --git a/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift b/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift
index 71cad64e8..f131358ef 100644
--- a/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift
+++ b/Sources/SpecialErrorPages/SpecialErrorPageUserScript.swift
@@ -23,11 +23,11 @@ import Common
public protocol SpecialErrorPageUserScriptDelegate: AnyObject {
- var errorData: SpecialErrorData? { get }
+ @MainActor var errorData: SpecialErrorData? { get }
- func leaveSite()
- func visitSite()
- func advancedInfoPresented()
+ @MainActor func leaveSiteAction()
+ @MainActor func visitSiteAction()
+ @MainActor func advancedInfoPresented()
}
@@ -105,13 +105,13 @@ public final class SpecialErrorPageUserScript: NSObject, Subfeature {
@MainActor
func handleLeaveSiteAction(params: Any, message: UserScriptMessage) -> Encodable? {
- delegate?.leaveSite()
+ delegate?.leaveSiteAction()
return nil
}
@MainActor
func handleVisitSiteAction(params: Any, message: UserScriptMessage) -> Encodable? {
- delegate?.visitSite()
+ delegate?.visitSiteAction()
return nil
}
diff --git a/Sources/Subscription/API/Model/Entitlement.swift b/Sources/Subscription/API/Model/Entitlement.swift
index c90e7342c..1d8eb645a 100644
--- a/Sources/Subscription/API/Model/Entitlement.swift
+++ b/Sources/Subscription/API/Model/Entitlement.swift
@@ -25,6 +25,7 @@ public struct Entitlement: Codable, Equatable {
case networkProtection = "Network Protection"
case dataBrokerProtection = "Data Broker Protection"
case identityTheftRestoration = "Identity Theft Restoration"
+ case identityTheftRestorationGlobal = "Global Identity Theft Restoration"
case unknown
public init(from decoder: Decoder) throws {
diff --git a/Sources/Subscription/API/SubscriptionEndpointService.swift b/Sources/Subscription/API/SubscriptionEndpointService.swift
index 7898bbddb..552c28d4a 100644
--- a/Sources/Subscription/API/SubscriptionEndpointService.swift
+++ b/Sources/Subscription/API/SubscriptionEndpointService.swift
@@ -27,6 +27,10 @@ public struct GetProductsItem: Decodable {
public let currency: String
}
+public struct GetSubscriptionFeaturesResponse: Decodable {
+ public let features: [Entitlement.ProductName]
+}
+
public struct GetCustomerPortalURLResponse: Decodable {
public let customerPortalUrl: String
}
@@ -47,6 +51,7 @@ public protocol SubscriptionEndpointService {
func getSubscription(accessToken: String, cachePolicy: APICachePolicy) async -> Result
func signOut()
func getProducts() async -> Result<[GetProductsItem], APIServiceError>
+ func getSubscriptionFeatures(for subscriptionID: String) async -> Result
func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result
func confirmPurchase(accessToken: String, signature: String) async -> Result
}
@@ -137,6 +142,12 @@ public struct DefaultSubscriptionEndpointService: SubscriptionEndpointService {
// MARK: -
+ public func getSubscriptionFeatures(for subscriptionID: String) async -> Result {
+ await apiService.executeAPICall(method: "GET", endpoint: "products/\(subscriptionID)/features", headers: nil, body: nil)
+ }
+
+ // MARK: -
+
public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result {
var headers = apiService.makeAuthorizationHeader(for: accessToken)
headers["externalAccountId"] = externalID
diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift b/Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift
similarity index 53%
rename from Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift
rename to Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift
index 4a56474e0..f2541131e 100644
--- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectorMock.swift
+++ b/Sources/Subscription/FeatureFlags/FeatureFlaggerMapping.swift
@@ -1,5 +1,5 @@
//
-// PhishingDetectorMock.swift
+// FeatureFlaggerMapping.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -17,22 +17,17 @@
//
import Foundation
-import PhishingDetection
-public class MockPhishingDetector: PhishingDetecting {
- private var mockClient: PhishingDetectionClientProtocol
- public var didCallIsMalicious: Bool = false
+open class FeatureFlaggerMapping {
+ public typealias Mapping = (_ feature: Feature) -> Bool
- init() {
- self.mockClient = MockPhishingDetectionClient()
- }
+ private let isFeatureEnabledMapping: Mapping
- public func getMatches(hashPrefix: String) async -> Set {
- let matches = await mockClient.getMatches(hashPrefix: hashPrefix)
- return Set(matches)
+ public init(mapping: @escaping Mapping) {
+ isFeatureEnabledMapping = mapping
}
- public func isMalicious(url: URL) async -> Bool {
- return url.absoluteString.contains("malicious")
+ public func isFeatureOn(_ feature: Feature) -> Bool {
+ return isFeatureEnabledMapping(feature)
}
}
diff --git a/Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift b/Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift
similarity index 54%
rename from Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift
rename to Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift
index 540c0c2e8..104d20066 100644
--- a/Sources/Subscription/Flows/Models/SubscriptionEnvironmentNames.swift
+++ b/Sources/Subscription/FeatureFlags/SubscriptionFeatureFlags.swift
@@ -1,5 +1,5 @@
//
-// SubscriptionEnvironmentNames.swift
+// SubscriptionFeatureFlags.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -18,18 +18,21 @@
import Foundation
-public enum SubscriptionFeatureName: String, CaseIterable {
- case privateBrowsing = "private-browsing"
- case privateSearch = "private-search"
- case emailProtection = "email-protection"
- case appTrackingProtection = "app-tracking-protection"
- case vpn = "vpn"
- case personalInformationRemoval = "personal-information-removal"
- case identityTheftRestoration = "identity-theft-restoration"
+public enum SubscriptionFeatureFlags {
+ case isLaunchedROW
+ case isLaunchedROWOverride
+ case usePrivacyProUSARegionOverride
+ case usePrivacyProROWRegionOverride
}
-public enum SubscriptionPlatformName: String {
- case ios
- case macos
- case stripe
+public extension SubscriptionFeatureFlags {
+
+ var defaultState: Bool {
+ switch self {
+ case .isLaunchedROW, .isLaunchedROWOverride:
+ return true
+ case .usePrivacyProUSARegionOverride, .usePrivacyProROWRegionOverride:
+ return false
+ }
+ }
}
diff --git a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift
index ca0347384..5544cf680 100644
--- a/Sources/Subscription/Flows/Models/SubscriptionOptions.swift
+++ b/Sources/Subscription/Flows/Models/SubscriptionOptions.swift
@@ -19,21 +19,34 @@
import Foundation
public struct SubscriptionOptions: Encodable, Equatable {
- let platform: String
+ let platform: SubscriptionPlatformName
let options: [SubscriptionOption]
let features: [SubscriptionFeature]
+
public static var empty: SubscriptionOptions {
- let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) }
+ let features = [SubscriptionFeature(name: .networkProtection),
+ SubscriptionFeature(name: .dataBrokerProtection),
+ SubscriptionFeature(name: .identityTheftRestoration)]
let platform: SubscriptionPlatformName
#if os(iOS)
platform = .ios
#else
platform = .macos
#endif
- return SubscriptionOptions(platform: platform.rawValue, options: [], features: features)
+ return SubscriptionOptions(platform: platform, options: [], features: features)
+ }
+
+ public func withoutPurchaseOptions() -> Self {
+ SubscriptionOptions(platform: platform, options: [], features: features)
}
}
+public enum SubscriptionPlatformName: String, Encodable {
+ case ios
+ case macos
+ case stripe
+}
+
public struct SubscriptionOption: Encodable, Equatable {
let id: String
let cost: SubscriptionOptionCost
@@ -45,5 +58,5 @@ struct SubscriptionOptionCost: Encodable, Equatable {
}
public struct SubscriptionFeature: Encodable, Equatable {
- let name: String
+ let name: Entitlement.ProductName
}
diff --git a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift
index 3912d96a2..43e0448e7 100644
--- a/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift
+++ b/Sources/Subscription/Flows/Stripe/StripePurchaseFlow.swift
@@ -71,9 +71,11 @@ public final class DefaultStripePurchaseFlow: StripePurchaseFlow {
cost: cost)
}
- let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) }
+ let features = [SubscriptionFeature(name: .networkProtection),
+ SubscriptionFeature(name: .dataBrokerProtection),
+ SubscriptionFeature(name: .identityTheftRestoration)]
- return .success(SubscriptionOptions(platform: SubscriptionPlatformName.stripe.rawValue,
+ return .success(SubscriptionOptions(platform: SubscriptionPlatformName.stripe,
options: options,
features: features))
}
diff --git a/Sources/Subscription/Managers/StorePurchaseManager.swift b/Sources/Subscription/Managers/StorePurchaseManager.swift
index 24ebc3d31..47791eb22 100644
--- a/Sources/Subscription/Managers/StorePurchaseManager.swift
+++ b/Sources/Subscription/Managers/StorePurchaseManager.swift
@@ -41,6 +41,7 @@ public protocol StorePurchaseManager {
var purchasedProductIDs: [String] { get }
var purchaseQueue: [String] { get }
var areProductsAvailable: Bool { get }
+ var currentStorefrontRegion: SubscriptionRegion { get }
@MainActor func syncAppleIDAccount() async throws
@MainActor func updateAvailableProducts() async
@@ -56,21 +57,24 @@ public protocol StorePurchaseManager {
@available(macOS 12.0, iOS 15.0, *)
public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseManager {
- let productIdentifiers = ["ios.subscription.1month", "ios.subscription.1year",
- "subscription.1month", "subscription.1year",
- "review.subscription.1month", "review.subscription.1year",
- "tf.sandbox.subscription.1month", "tf.sandbox.subscription.1year",
- "ddg.privacy.pro.monthly.renews.us", "ddg.privacy.pro.yearly.renews.us"]
+ private let storeSubscriptionConfiguration: StoreSubscriptionConfiguration
+ private let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache
+ private let subscriptionFeatureFlagger: FeatureFlaggerMapping?
@Published public private(set) var availableProducts: [Product] = []
@Published public private(set) var purchasedProductIDs: [String] = []
@Published public private(set) var purchaseQueue: [String] = []
public var areProductsAvailable: Bool { !availableProducts.isEmpty }
+ public private(set) var currentStorefrontRegion: SubscriptionRegion = .usa
private var transactionUpdates: Task?
private var storefrontChanges: Task?
- public init() {
+ public init(subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache,
+ subscriptionFeatureFlagger: FeatureFlaggerMapping? = nil) {
+ self.storeSubscriptionConfiguration = DefaultStoreSubscriptionConfiguration()
+ self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache
+ self.subscriptionFeatureFlagger = subscriptionFeatureFlagger
transactionUpdates = observeTransactionUpdates()
storefrontChanges = observeStorefrontChanges()
}
@@ -109,17 +113,29 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM
return nil
}
- let options = [SubscriptionOption(id: monthly.id, cost: .init(displayPrice: monthly.displayPrice, recurrence: "monthly")),
- SubscriptionOption(id: yearly.id, cost: .init(displayPrice: yearly.displayPrice, recurrence: "yearly"))]
- let features = SubscriptionFeatureName.allCases.map { SubscriptionFeature(name: $0.rawValue) }
- let platform: SubscriptionPlatformName
-
+ let platform: SubscriptionPlatformName = {
#if os(iOS)
- platform = .ios
+ .ios
#else
- platform = .macos
+ .macos
#endif
- return SubscriptionOptions(platform: platform.rawValue,
+ }()
+
+ let options = [SubscriptionOption(id: monthly.id,
+ cost: .init(displayPrice: monthly.displayPrice, recurrence: "monthly")),
+ SubscriptionOption(id: yearly.id,
+ cost: .init(displayPrice: yearly.displayPrice, recurrence: "yearly"))]
+
+ let features: [SubscriptionFeature]
+
+ if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) {
+ features = await subscriptionFeatureMappingCache.subscriptionFeatures(for: monthly.id).compactMap { SubscriptionFeature(name: $0) }
+ } else {
+ let allFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration]
+ features = allFeatures.compactMap { SubscriptionFeature(name: $0) }
+ }
+
+ return SubscriptionOptions(platform: platform,
options: options,
features: features)
}
@@ -129,11 +145,36 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM
Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts")
do {
- let availableProducts = try await Product.products(for: productIdentifiers)
- Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts fetched \(availableProducts.count) products")
+ let storefrontCountryCode: String?
+ let storefrontRegion: SubscriptionRegion
+
+ if let featureFlagger = subscriptionFeatureFlagger, featureFlagger.isFeatureOn(.isLaunchedROW) || featureFlagger.isFeatureOn(.isLaunchedROWOverride) {
+ if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProUSARegionOverride) {
+ storefrontCountryCode = "USA"
+ } else if let subscriptionFeatureFlagger, subscriptionFeatureFlagger.isFeatureOn(.usePrivacyProROWRegionOverride) {
+ storefrontCountryCode = "POL"
+ } else {
+ storefrontCountryCode = await Storefront.current?.countryCode
+ }
+
+ storefrontRegion = SubscriptionRegion.matchingRegion(for: storefrontCountryCode ?? "USA") ?? .usa // Fallback to USA
+ } else {
+ storefrontCountryCode = "USA"
+ storefrontRegion = .usa
+ }
+
+ self.currentStorefrontRegion = storefrontRegion
+ let applicableProductIdentifiers = storeSubscriptionConfiguration.subscriptionIdentifiers(for: storefrontRegion)
+ let availableProducts = try await Product.products(for: applicableProductIdentifiers)
+ Logger.subscription.info("[StorePurchaseManager] updateAvailableProducts fetched \(availableProducts.count) products for \(storefrontCountryCode ?? "", privacy: .public)")
if self.availableProducts != availableProducts {
self.availableProducts = availableProducts
+
+ // Update cached subscription features mapping
+ for id in availableProducts.compactMap({ $0.id }) {
+ _ = await subscriptionFeatureMappingCache.subscriptionFeatures(for: id)
+ }
}
} catch {
Logger.subscription.error("[StorePurchaseManager] Error: \(String(reflecting: error), privacy: .public)")
@@ -295,3 +336,36 @@ public final class DefaultStorePurchaseManager: ObservableObject, StorePurchaseM
}
}
}
+
+public extension UserDefaults {
+
+ enum Constants {
+ static let storefrontRegionOverrideKey = "Subscription.debug.storefrontRegionOverride"
+ static let usaValue = "usa"
+ static let rowValue = "row"
+ }
+
+ dynamic var storefrontRegionOverride: SubscriptionRegion? {
+ get {
+ switch string(forKey: Constants.storefrontRegionOverrideKey) {
+ case "usa":
+ return .usa
+ case "row":
+ return .restOfWorld
+ default:
+ return nil
+ }
+ }
+
+ set {
+ switch newValue {
+ case .usa:
+ set(Constants.usaValue, forKey: Constants.storefrontRegionOverrideKey)
+ case .restOfWorld:
+ set(Constants.rowValue, forKey: Constants.storefrontRegionOverrideKey)
+ default:
+ removeObject(forKey: Constants.storefrontRegionOverrideKey)
+ }
+ }
+ }
+}
diff --git a/Sources/Subscription/Managers/SubscriptionManager.swift b/Sources/Subscription/Managers/SubscriptionManager.swift
index ef8abfe85..cac861106 100644
--- a/Sources/Subscription/Managers/SubscriptionManager.swift
+++ b/Sources/Subscription/Managers/SubscriptionManager.swift
@@ -24,6 +24,7 @@ public protocol SubscriptionManager {
var accountManager: AccountManager { get }
var subscriptionEndpointService: SubscriptionEndpointService { get }
var authEndpointService: AuthEndpointService { get }
+ var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache { get }
// Environment
static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment?
@@ -35,6 +36,7 @@ public protocol SubscriptionManager {
func loadInitialData()
func refreshCachedSubscriptionAndEntitlements(completion: @escaping (_ isSubscriptionActive: Bool) -> Void)
func url(for type: SubscriptionURL) -> URL
+ func currentSubscriptionFeatures() async -> [Entitlement.ProductName]
}
/// Single entry point for everything related to Subscription. This manager is disposable, every time something related to the environment changes this need to be recreated.
@@ -43,19 +45,26 @@ public final class DefaultSubscriptionManager: SubscriptionManager {
public let accountManager: AccountManager
public let subscriptionEndpointService: SubscriptionEndpointService
public let authEndpointService: AuthEndpointService
+ public let subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache
public let currentEnvironment: SubscriptionEnvironment
- public private(set) var canPurchase: Bool = false
+
+ private let subscriptionFeatureFlagger: FeatureFlaggerMapping
public init(storePurchaseManager: StorePurchaseManager? = nil,
accountManager: AccountManager,
subscriptionEndpointService: SubscriptionEndpointService,
authEndpointService: AuthEndpointService,
- subscriptionEnvironment: SubscriptionEnvironment) {
+ subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache,
+ subscriptionEnvironment: SubscriptionEnvironment,
+ subscriptionFeatureFlagger: FeatureFlaggerMapping) {
self._storePurchaseManager = storePurchaseManager
self.accountManager = accountManager
self.subscriptionEndpointService = subscriptionEndpointService
self.authEndpointService = authEndpointService
+ self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache
self.currentEnvironment = subscriptionEnvironment
+ self.subscriptionFeatureFlagger = subscriptionFeatureFlagger
+
switch currentEnvironment.purchasePlatform {
case .appStore:
if #available(macOS 12.0, iOS 15.0, *) {
@@ -68,6 +77,21 @@ public final class DefaultSubscriptionManager: SubscriptionManager {
}
}
+ public var canPurchase: Bool {
+ guard let storePurchaseManager = _storePurchaseManager else { return false }
+
+ switch storePurchaseManager.currentStorefrontRegion {
+ case .usa:
+ return storePurchaseManager.areProductsAvailable
+ case .restOfWorld:
+ if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) {
+ return storePurchaseManager.areProductsAvailable
+ } else {
+ return false
+ }
+ }
+ }
+
@available(macOS 12.0, iOS 15.0, *)
public func storePurchaseManager() -> StorePurchaseManager {
return _storePurchaseManager!
@@ -98,7 +122,6 @@ public final class DefaultSubscriptionManager: SubscriptionManager {
@available(macOS 12.0, iOS 15.0, *) private func setupForAppStore() {
Task {
await storePurchaseManager().updateAvailableProducts()
- canPurchase = storePurchaseManager().areProductsAvailable
}
}
@@ -147,4 +170,21 @@ public final class DefaultSubscriptionManager: SubscriptionManager {
public func url(for type: SubscriptionURL) -> URL {
type.subscriptionURL(environment: currentEnvironment.serviceEnvironment)
}
+
+ // MARK: - Current subscription's features
+
+ public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] {
+ guard let token = accountManager.accessToken else { return [] }
+
+ if subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROW) || subscriptionFeatureFlagger.isFeatureOn(.isLaunchedROWOverride) {
+ switch await subscriptionEndpointService.getSubscription(accessToken: token, cachePolicy: .returnCacheDataElseLoad) {
+ case .success(let subscription):
+ return await subscriptionFeatureMappingCache.subscriptionFeatures(for: subscription.productId)
+ case .failure:
+ return []
+ }
+ } else {
+ return [.networkProtection, .dataBrokerProtection, .identityTheftRestoration]
+ }
+ }
}
diff --git a/Sources/Subscription/NSNotificationName+Subscription.swift b/Sources/Subscription/NSNotificationName+Subscription.swift
index 1079c6231..b674aaff7 100644
--- a/Sources/Subscription/NSNotificationName+Subscription.swift
+++ b/Sources/Subscription/NSNotificationName+Subscription.swift
@@ -20,10 +20,6 @@ import Foundation
public extension NSNotification.Name {
- static let openPrivateBrowsing = Notification.Name("com.duckduckgo.subscription.open.private-browsing")
- static let openPrivateSearch = Notification.Name("com.duckduckgo.subscription.open.private-search")
- static let openEmailProtection = Notification.Name("com.duckduckgo.subscription.open.email-protection")
- static let openAppTrackingProtection = Notification.Name("com.duckduckgo.subscription.open.app-tracking-protection")
static let openVPN = Notification.Name("com.duckduckgo.subscription.open.vpn")
static let openPersonalInformationRemoval = Notification.Name("com.duckduckgo.subscription.open.personal-information-removal")
static let openIdentityTheftRestoration = Notification.Name("com.duckduckgo.subscription.open.identity-theft-restoration")
diff --git a/Sources/Subscription/StoreSubscriptionConfiguration.swift b/Sources/Subscription/StoreSubscriptionConfiguration.swift
new file mode 100644
index 000000000..8dacf7db4
--- /dev/null
+++ b/Sources/Subscription/StoreSubscriptionConfiguration.swift
@@ -0,0 +1,138 @@
+//
+// StoreSubscriptionConfiguration.swift
+//
+// Copyright © 2023 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import Combine
+
+protocol StoreSubscriptionConfiguration {
+ var allSubscriptionIdentifiers: [String] { get }
+ func subscriptionIdentifiers(for country: String) -> [String]
+ func subscriptionIdentifiers(for region: SubscriptionRegion) -> [String]
+}
+
+final class DefaultStoreSubscriptionConfiguration: StoreSubscriptionConfiguration {
+
+ private let subscriptions: [StoreSubscriptionDefinition]
+
+ convenience init() {
+ self.init(subscriptionDefinitions: [
+ // Production shared for iOS and macOS
+ .init(name: "DuckDuckGo Private Browser",
+ appIdentifier: "com.duckduckgo.mobile.ios",
+ environment: .production,
+ identifiersByRegion: [.usa: ["ddg.privacy.pro.monthly.renews.us",
+ "ddg.privacy.pro.yearly.renews.us"],
+ .restOfWorld: ["ddg.privacy.pro.monthly.renews.row",
+ "ddg.privacy.pro.yearly.renews.row"]]),
+ // iOS debug Alpha build
+ .init(name: "DuckDuckGo Alpha",
+ appIdentifier: "com.duckduckgo.mobile.ios.alpha",
+ environment: .staging,
+ identifiersByRegion: [.usa: ["ios.subscription.1month",
+ "ios.subscription.1year"],
+ .restOfWorld: ["ios.subscription.1month.row",
+ "ios.subscription.1year.row"]]),
+ // macOS debug build
+ .init(name: "IAP debug - DDG for macOS",
+ appIdentifier: "com.duckduckgo.macos.browser.debug",
+ environment: .staging,
+ identifiersByRegion: [.usa: ["subscription.1month",
+ "subscription.1year"],
+ .restOfWorld: ["subscription.1month.row",
+ "subscription.1year.row"]]),
+ // macOS review build
+ .init(name: "IAP review - DDG for macOS",
+ appIdentifier: "com.duckduckgo.macos.browser.review",
+ environment: .staging,
+ identifiersByRegion: [.usa: ["review.subscription.1month",
+ "review.subscription.1year"],
+ .restOfWorld: ["review.subscription.1month.row",
+ "review.subscription.1year.row"]]),
+
+ // macOS TestFlight build
+ .init(name: "DuckDuckGo Sandbox Review",
+ appIdentifier: "com.duckduckgo.mobile.ios.review",
+ environment: .staging,
+ identifiersByRegion: [.usa: ["tf.sandbox.subscription.1month",
+ "tf.sandbox.subscription.1year"],
+ .restOfWorld: ["tf.sandbox.subscription.1month.row",
+ "tf.sandbox.subscription.1year.row"]])
+ ])
+ }
+
+ init(subscriptionDefinitions: [StoreSubscriptionDefinition]) {
+ self.subscriptions = subscriptionDefinitions
+ }
+
+ var allSubscriptionIdentifiers: [String] {
+ subscriptions.reduce([], { $0 + $1.allIdentifiers() })
+ }
+
+ func subscriptionIdentifiers(for country: String) -> [String] {
+ subscriptions.reduce([], { $0 + $1.identifiers(for: country) })
+ }
+
+ func subscriptionIdentifiers(for region: SubscriptionRegion) -> [String] {
+ subscriptions.reduce([], { $0 + $1.identifiers(for: region) })
+ }
+}
+
+struct StoreSubscriptionDefinition {
+ var name: String
+ var appIdentifier: String
+ var environment: SubscriptionEnvironment.ServiceEnvironment
+ var identifiersByRegion: [SubscriptionRegion: [String]]
+
+ func allIdentifiers() -> [String] {
+ identifiersByRegion.values.flatMap { $0 }
+ }
+
+ func identifiers(for country: String) -> [String] {
+ identifiersByRegion.filter { region, _ in region.contains(country) }.flatMap { _, identifiers in identifiers }
+ }
+
+ func identifiers(for region: SubscriptionRegion) -> [String] {
+ identifiersByRegion[region] ?? []
+ }
+}
+
+public enum SubscriptionRegion: CaseIterable {
+ case usa
+ case restOfWorld
+
+ /// Country codes as used by StoreKit, in the ISO 3166-1 Alpha-3 country code representation
+ /// For .restOfWorld definiton see https://app.asana.com/0/1208524871249522/1208571752166956/f
+ var countryCodes: Set {
+ switch self {
+ case .usa:
+ return Set(["USA"])
+ case .restOfWorld:
+ return Set(["CAN", "GBR", "AUT", "DEU", "NLD", "POL", "SWE",
+ "BEL", "BGR", "HRV ", "CYP", "CZE", "DNK", "EST", "FIN", "FRA", "GRC", "HUN", "IRL", "ITA", "LVA", "LTU", "LUX", "MLT", "PRT",
+ "ROU", "SVK", "SVN", "ESP"])
+ }
+ }
+
+ func contains(_ country: String) -> Bool {
+ countryCodes.contains(country.uppercased())
+ }
+
+ static func matchingRegion(for countryCode: String) -> Self? {
+ Self.allCases.first { $0.countryCodes.contains(countryCode) }
+ }
+}
diff --git a/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift b/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift
index 74fec5568..3b561c637 100644
--- a/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift
+++ b/Sources/Subscription/SubscriptionCookie/HTTPCookieStore.swift
@@ -25,4 +25,36 @@ public protocol HTTPCookieStore {
func deleteCookie(_ cookie: HTTPCookie) async
}
-extension WKHTTPCookieStore: HTTPCookieStore {}
+@MainActor
+public struct WKHTTPCookieStoreWrapper: HTTPCookieStore {
+
+ let store: WKHTTPCookieStore
+
+ public init(store: WKHTTPCookieStore) {
+ self.store = store
+ }
+
+ public func allCookies() async -> [HTTPCookie] {
+ return await withCheckedContinuation { continuation in
+ store.getAllCookies { cookies in
+ continuation.resume(returning: cookies)
+ }
+ }
+ }
+
+ public func setCookie(_ cookie: HTTPCookie) async {
+ await withCheckedContinuation { continuation in
+ store.setCookie(cookie) {
+ continuation.resume()
+ }
+ }
+ }
+
+ public func deleteCookie(_ cookie: HTTPCookie) async {
+ await withCheckedContinuation { continuation in
+ store.delete(cookie) {
+ continuation.resume()
+ }
+ }
+ }
+}
diff --git a/Sources/Subscription/SubscriptionFeatureMappingCache.swift b/Sources/Subscription/SubscriptionFeatureMappingCache.swift
new file mode 100644
index 000000000..414fec1d7
--- /dev/null
+++ b/Sources/Subscription/SubscriptionFeatureMappingCache.swift
@@ -0,0 +1,136 @@
+//
+// SubscriptionFeatureMappingCache.swift
+//
+// Copyright © 2023 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import os.log
+
+public protocol SubscriptionFeatureMappingCache {
+ func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName]
+}
+
+public final class DefaultSubscriptionFeatureMappingCache: SubscriptionFeatureMappingCache {
+
+ private let subscriptionEndpointService: SubscriptionEndpointService
+ private let userDefaults: UserDefaults
+
+ private var subscriptionFeatureMapping: SubscriptionFeatureMapping?
+
+ public init(subscriptionEndpointService: SubscriptionEndpointService, userDefaults: UserDefaults) {
+ self.subscriptionEndpointService = subscriptionEndpointService
+ self.userDefaults = userDefaults
+ }
+
+ public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] \(#function) \(subscriptionIdentifier)")
+ let features: [Entitlement.ProductName]
+
+ if let subscriptionFeatures = currentSubscriptionFeatureMapping[subscriptionIdentifier] {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] - got cached features")
+ features = subscriptionFeatures
+ } else if let subscriptionFeatures = await fetchRemoteFeatures(for: subscriptionIdentifier) {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] - fetching features from BE API")
+ features = subscriptionFeatures
+ updateCachedFeatureMapping(with: subscriptionFeatures, for: subscriptionIdentifier)
+ } else {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] - Error: using fallback")
+ features = fallbackFeatures
+ }
+
+ return features
+ }
+
+ // MARK: - Current feature mapping
+
+ private var currentSubscriptionFeatureMapping: SubscriptionFeatureMapping {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] - \(#function)")
+ let featureMapping: SubscriptionFeatureMapping
+
+ if let cachedFeatureMapping {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- got cachedFeatureMapping")
+ featureMapping = cachedFeatureMapping
+ } else if let storedFeatureMapping {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- have to fetchStoredFeatureMapping")
+ featureMapping = storedFeatureMapping
+ updateCachedFeatureMapping(to: featureMapping)
+ } else {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- so creating a new one!")
+ featureMapping = SubscriptionFeatureMapping()
+ updateCachedFeatureMapping(to: featureMapping)
+ }
+
+ return featureMapping
+ }
+
+ // MARK: - Cached subscription feature mapping
+
+ private var cachedFeatureMapping: SubscriptionFeatureMapping?
+
+ private func updateCachedFeatureMapping(to featureMapping: SubscriptionFeatureMapping) {
+ cachedFeatureMapping = featureMapping
+ }
+
+ private func updateCachedFeatureMapping(with features: [Entitlement.ProductName], for subscriptionIdentifier: String) {
+ var updatedFeatureMapping = cachedFeatureMapping ?? SubscriptionFeatureMapping()
+ updatedFeatureMapping[subscriptionIdentifier] = features
+
+ self.cachedFeatureMapping = updatedFeatureMapping
+ self.storedFeatureMapping = updatedFeatureMapping
+ }
+
+ // MARK: - Stored subscription feature mapping
+
+ static private let subscriptionFeatureMappingKey = "com.duckduckgo.subscription.featuremapping"
+
+ dynamic var storedFeatureMapping: SubscriptionFeatureMapping? {
+ get {
+ guard let data = userDefaults.data(forKey: Self.subscriptionFeatureMappingKey) else { return nil }
+ do {
+ return try JSONDecoder().decode(SubscriptionFeatureMapping?.self, from: data)
+ } catch {
+ assertionFailure("Errored while decoding feature mapping")
+ return nil
+ }
+ }
+
+ set {
+ do {
+ let data = try JSONEncoder().encode(newValue)
+ userDefaults.set(data, forKey: Self.subscriptionFeatureMappingKey)
+ } catch {
+ assertionFailure("Errored while encoding feature mapping")
+ }
+ }
+ }
+
+ // MARK: - Remote subscription feature mapping
+
+ private func fetchRemoteFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName]? {
+ if case let .success(response) = await subscriptionEndpointService.getSubscriptionFeatures(for: subscriptionIdentifier) {
+ Logger.subscription.debug("[SubscriptionFeatureMappingCache] -- Fetched features for `\(subscriptionIdentifier)`: \(response.features)")
+ return response.features
+ }
+
+ return nil
+ }
+
+ // MARK: - Fallback subscription feature mapping
+
+ private let fallbackFeatures: [Entitlement.ProductName] = [.networkProtection, .dataBrokerProtection, .identityTheftRestoration]
+}
+
+typealias SubscriptionFeatureMapping = [String: [Entitlement.ProductName]]
diff --git a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift
index 132685772..b17d585d0 100644
--- a/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift
+++ b/Sources/SubscriptionTestingUtilities/APIs/SubscriptionEndpointServiceMock.swift
@@ -22,6 +22,7 @@ import Subscription
public final class SubscriptionEndpointServiceMock: SubscriptionEndpointService {
public var getSubscriptionResult: Result?
public var getProductsResult: Result<[GetProductsItem], APIServiceError>?
+ public var getSubscriptionFeaturesResult: Result?
public var getCustomerPortalURLResult: Result?
public var confirmPurchaseResult: Result?
@@ -55,6 +56,10 @@ public final class SubscriptionEndpointServiceMock: SubscriptionEndpointService
getProductsResult!
}
+ public func getSubscriptionFeatures(for subscriptionID: String) async -> Result {
+ getSubscriptionFeaturesResult!
+ }
+
public func getCustomerPortalURL(accessToken: String, externalID: String) async -> Result {
getCustomerPortalURLResult!
}
diff --git a/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift
index 6f654ba72..e326fd720 100644
--- a/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift
+++ b/Sources/SubscriptionTestingUtilities/Managers/StorePurchaseManagerMock.swift
@@ -23,6 +23,8 @@ public final class StorePurchaseManagerMock: StorePurchaseManager {
public var purchasedProductIDs: [String] = []
public var purchaseQueue: [String] = []
public var areProductsAvailable: Bool = false
+ public var currentStorefrontRegion: SubscriptionRegion = .usa
+
public var subscriptionOptionsResult: SubscriptionOptions?
public var syncAppleIDAccountResultError: Error?
diff --git a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift
index 46fdc77e0..3217eedbb 100644
--- a/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift
+++ b/Sources/SubscriptionTestingUtilities/Managers/SubscriptionManagerMock.swift
@@ -23,6 +23,7 @@ public final class SubscriptionManagerMock: SubscriptionManager {
public var accountManager: AccountManager
public var subscriptionEndpointService: SubscriptionEndpointService
public var authEndpointService: AuthEndpointService
+ public var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache
public static var storedEnvironment: SubscriptionEnvironment?
public static func loadEnvironmentFrom(userDefaults: UserDefaults) -> SubscriptionEnvironment? {
@@ -52,18 +53,24 @@ public final class SubscriptionManagerMock: SubscriptionManager {
type.subscriptionURL(environment: currentEnvironment.serviceEnvironment)
}
+ public func currentSubscriptionFeatures() async -> [Entitlement.ProductName] {
+ return []
+ }
+
public init(accountManager: AccountManager,
subscriptionEndpointService: SubscriptionEndpointService,
authEndpointService: AuthEndpointService,
storePurchaseManager: StorePurchaseManager,
currentEnvironment: SubscriptionEnvironment,
- canPurchase: Bool) {
+ canPurchase: Bool,
+ subscriptionFeatureMappingCache: SubscriptionFeatureMappingCache) {
self.accountManager = accountManager
self.subscriptionEndpointService = subscriptionEndpointService
self.authEndpointService = authEndpointService
self.internalStorePurchaseManager = storePurchaseManager
self.currentEnvironment = currentEnvironment
self.canPurchase = canPurchase
+ self.subscriptionFeatureMappingCache = subscriptionFeatureMappingCache
}
// MARK: -
diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift
index a9166689a..b2a5b8133 100644
--- a/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift
+++ b/Sources/SubscriptionTestingUtilities/SubscriptionCookie/SubscriptionCookieManagerMock.swift
@@ -29,13 +29,15 @@ public final class SubscriptionCookieManagerMock: SubscriptionCookieManaging {
let subscriptionService = DefaultSubscriptionEndpointService(currentServiceEnvironment: .production)
let authService = DefaultAuthEndpointService(currentServiceEnvironment: .production)
let storePurchaseManager = StorePurchaseManagerMock()
+ let subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock()
let subscriptionManager = SubscriptionManagerMock(accountManager: accountManager,
subscriptionEndpointService: subscriptionService,
authEndpointService: authService,
storePurchaseManager: storePurchaseManager,
currentEnvironment: SubscriptionEnvironment(serviceEnvironment: .production,
purchasePlatform: .appStore),
- canPurchase: true)
+ canPurchase: true,
+ subscriptionFeatureMappingCache: subscriptionFeatureMappingCache)
self.init(subscriptionManager: subscriptionManager,
currentCookieStore: { return nil },
diff --git a/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift
new file mode 100644
index 000000000..c4612a9bd
--- /dev/null
+++ b/Sources/SubscriptionTestingUtilities/SubscriptionFeatureMappingCacheMock.swift
@@ -0,0 +1,31 @@
+//
+// SubscriptionFeatureMappingCacheMock.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import Subscription
+
+public final class SubscriptionFeatureMappingCacheMock: SubscriptionFeatureMappingCache {
+
+ public var mapping: [String: [Entitlement.ProductName]] = [:]
+
+ public init() { }
+
+ public func subscriptionFeatures(for subscriptionIdentifier: String) async -> [Entitlement.ProductName] {
+ return mapping[subscriptionIdentifier] ?? []
+ }
+}
diff --git a/Sources/TestUtils/MockAPIService.swift b/Sources/TestUtils/MockAPIService.swift
index be4bf47a2..f4d35b4b6 100644
--- a/Sources/TestUtils/MockAPIService.swift
+++ b/Sources/TestUtils/MockAPIService.swift
@@ -19,12 +19,16 @@
import Foundation
import Networking
-public struct MockAPIService: APIService {
+public class MockAPIService: APIService {
- public var apiResponse: Result
+ public var requestHandler: ((APIRequestV2) -> Result)!
- public func fetch(request: Networking.APIRequestV2) async throws -> APIResponseV2 {
- switch apiResponse {
+ public init(requestHandler: ((APIRequestV2) -> Result)? = nil) {
+ self.requestHandler = requestHandler
+ }
+
+ public func fetch(request: APIRequestV2) async throws -> APIResponseV2 {
+ switch requestHandler!(request) {
case .success(let result):
return result
case .failure(let error):
diff --git a/Sources/UserScript/UserScript.swift b/Sources/UserScript/UserScript.swift
index 3b35ddc42..728b3b36a 100644
--- a/Sources/UserScript/UserScript.swift
+++ b/Sources/UserScript/UserScript.swift
@@ -107,7 +107,7 @@ extension UserScript {
}
public func makeWKUserScript() async -> WKUserScriptBox {
- let source = (try? await Task.detached { [source] in Self.prepareScriptSource(from: source) }.result.get())!
+ let source = await Task.detached { [source] in Self.prepareScriptSource(from: source) }.result.get()
return await Self.makeWKUserScript(from: source,
injectionTime: injectionTime,
forMainFrameOnly: forMainFrameOnly,
diff --git a/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift b/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift
index 9a3d7610c..29adc74fb 100644
--- a/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift
+++ b/Tests/BrokenSitePromptTests/PrivacyConfigurationManagerMock.swift
@@ -59,7 +59,7 @@ class PrivacyConfigurationMock: PrivacyConfiguration {
return enabledSubfeaturesForVersions[subfeature.rawValue]?.contains(versionProvider.appVersion() ?? "") ?? false
}
- func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
+ func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
if isSubfeatureEnabled(subfeature, versionProvider: versionProvider, randomizer: randomizer) {
return .enabled
}
@@ -98,6 +98,18 @@ class PrivacyConfigurationMock: PrivacyConfiguration {
return userUnprotected.contains(domain ?? "")
}
+ func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
+ return .enabled
+ }
+
+ func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
+ func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
}
class PrivacyConfigurationManagerMock: PrivacyConfigurationManaging {
diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift
index f6082a7b2..455734d9b 100644
--- a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift
+++ b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerInitialCompilationTests.swift
@@ -24,8 +24,9 @@ import TrackerRadarKit
import BrowserServicesKit
import WebKit
import XCTest
+import Common
-final class CountedFulfillmentTestExpectation: XCTestExpectation {
+final class CountedFulfillmentTestExpectation: XCTestExpectation, @unchecked Sendable {
private(set) var currentFulfillmentCount: Int = 0
override func fulfill() {
@@ -57,12 +58,24 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
expStore.fulfill()
}
+ let lookupAndFetchExp = self.expectation(description: "LRC should be missing")
+ let errorHandler = EventMapping { event, _, params, _ in
+ if case .contentBlockingLRCMissing = event {
+ lookupAndFetchExp.fulfill()
+ } else if case .contentBlockingCompilationTime = event {
+ XCTAssertNotNil(params?["compilationTime"])
+ } else {
+ XCTFail("Unexpected event: \(event)")
+ }
+ }
+
let cbrm = ContentBlockerRulesManager(rulesSource: mockRulesSource,
exceptionsSource: mockExceptionsSource,
lastCompiledRulesStore: mockLastCompiledRulesStore,
- updateListener: rulesUpdateListener)
+ updateListener: rulesUpdateListener,
+ errorReporting: errorHandler)
- wait(for: [exp, expStore], timeout: 15.0)
+ wait(for: [exp, expStore, lookupAndFetchExp], timeout: 15.0)
XCTAssertNotNil(mockLastCompiledRulesStore.rules)
XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, mockRulesSource.trackerData?.etag)
@@ -93,6 +106,8 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
XCTFail("Should use rules cached by WebKit")
}
+ let lookupAndFetchExp = self.expectation(description: "Should not fetch LRC")
+
// simulate the rules have been compiled in the past so the WKContentRuleListStore contains it
_ = ContentBlockerRulesManager(rulesSource: mockRulesSource,
exceptionsSource: mockExceptionsSource,
@@ -103,15 +118,28 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
rulesUpdateListener.onRulesUpdated = { rules in
exp.fulfill()
if exp.currentFulfillmentCount == 1 { // finished compilation after first installation
+ let errorHandler = EventMapping { event, _, params, _ in
+ if case .contentBlockingFetchLRCSucceeded = event {
+ XCTFail("Should not fetch LRC")
+ } else if case .contentBlockingLookupRulesSucceeded = event {
+ lookupAndFetchExp.fulfill()
+ } else if case .contentBlockingCompilationTime = event {
+ XCTAssertNotNil(params?["compilationTime"])
+ } else {
+ XCTFail("Unexpected event: \(event)")
+ }
+ }
+
_ = ContentBlockerRulesManager(rulesSource: mockRulesSource,
exceptionsSource: mockExceptionsSource,
lastCompiledRulesStore: mockLastCompiledRulesStore,
- updateListener: self.rulesUpdateListener)
+ updateListener: self.rulesUpdateListener,
+ errorReporting: errorHandler)
}
assertRules(rules)
}
- wait(for: [exp], timeout: 15.0)
+ wait(for: [exp, lookupAndFetchExp], timeout: 15.0)
func assertRules(_ rules: [ContentBlockerRulesManager.Rules]) {
guard let rules = rules.first else { XCTFail("Couldn't get rules"); return }
@@ -178,12 +206,27 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
XCTAssertEqual(newListName, rules.name)
}
+ let lookupAndFetchExp = self.expectation(description: "Should not fetch LRC")
+
+ let errorHandler = EventMapping { event, _, params, _ in
+ if case .contentBlockingFetchLRCSucceeded = event {
+ XCTFail("Should not fetch LRC")
+ } else if case .contentBlockingNoMatchInLRC = event {
+ lookupAndFetchExp.fulfill()
+ } else if case .contentBlockingCompilationTime = event {
+ XCTAssertNotNil(params?["compilationTime"])
+ } else {
+ XCTFail("Unexpected event: \(event)")
+ }
+ }
+
_ = ContentBlockerRulesManager(rulesSource: mockRulesSource,
exceptionsSource: mockExceptionsSource,
lastCompiledRulesStore: mockLastCompiledRulesStore,
- updateListener: rulesUpdateListener)
+ updateListener: rulesUpdateListener,
+ errorReporting: errorHandler)
- wait(for: [expCacheLookup, expNext], timeout: 15.0)
+ wait(for: [expCacheLookup, expNext, lookupAndFetchExp], timeout: 15.0)
}
func testInitialCompilation_WhenThereAreChangesToTDS_ShouldBuildRulesUsingLastCompiledRulesAndScheduleRecompilationWithNewSource() {
@@ -220,14 +263,26 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
exceptionsSource: mockExceptionsSource,
updateListener: rulesUpdateListener)
+ let lookupAndFetchExp = self.expectation(description: "Fetch LRC succeeded")
let expOld = CountedFulfillmentTestExpectation(description: "Old Rules Compiled")
rulesUpdateListener.onRulesUpdated = { _ in
expOld.fulfill()
+ let errorHandler = EventMapping { event, _, params, _ in
+ if case .contentBlockingFetchLRCSucceeded = event {
+ lookupAndFetchExp.fulfill()
+ } else if case .contentBlockingCompilationTime = event {
+ XCTAssertNotNil(params?["compilationTime"])
+ } else {
+ XCTFail("Unexpected event: \(event)")
+ }
+ }
+
_ = ContentBlockerRulesManager(rulesSource: mockUpdatedRulesSource,
exceptionsSource: mockExceptionsSource,
lastCompiledRulesStore: mockLastCompiledRulesStore,
- updateListener: self.rulesUpdateListener)
+ updateListener: self.rulesUpdateListener,
+ errorReporting: errorHandler)
}
wait(for: [expOld], timeout: 15.0)
@@ -237,27 +292,27 @@ final class ContentBlockerRulesManagerInitialCompilationTests: XCTestCase {
expLastCompiledFetched.fulfill()
}
- let expRecompiled = CountedFulfillmentTestExpectation(description: "New Rules Compiled")
- rulesUpdateListener.onRulesUpdated = { _ in
- expRecompiled.fulfill()
- }
+ let expRecompiled = CountedFulfillmentTestExpectation(description: "New Rules Compiled")
+ rulesUpdateListener.onRulesUpdated = { _ in
+ expRecompiled.fulfill()
+
+ if expRecompiled.currentFulfillmentCount == 1 { // finished compilation after cold start (using last compiled rules)
+ mockLastCompiledRulesStore.onRulesGet = {}
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, oldEtag)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, oldIdentifier)
+ } else if expRecompiled.currentFulfillmentCount == 2 { // finished recompilation of rules due to changed tds
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, updatedEtag)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds)
+ XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, newIdentifier)
+ }
+ }
- if expRecompiled.currentFulfillmentCount == 1 { // finished compilation after cold start (using last compiled rules)
-
- mockLastCompiledRulesStore.onRulesGet = {}
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, oldEtag)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, oldIdentifier)
- } else if expRecompiled.currentFulfillmentCount == 2 { // finished recompilation of rules due to changed tds
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.etag, updatedEtag)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.name, mockRulesSource.ruleListName)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.trackerData, mockRulesSource.trackerData?.tds)
- XCTAssertEqual(mockLastCompiledRulesStore.rules.first?.identifier, newIdentifier)
- }
+ wait(for: [expLastCompiledFetched, expRecompiled, lookupAndFetchExp], timeout: 15.0)
- wait(for: [expLastCompiledFetched, expRecompiled], timeout: 15.0)
- }
+ }
struct MockLastCompiledRules: LastCompiledRules {
diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift
index 60df0f3a1..15321e1ba 100644
--- a/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift
+++ b/Tests/BrowserServicesKitTests/ContentBlocker/ContentBlockerRulesManagerTests.swift
@@ -203,6 +203,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
let errorExp = expectation(description: "No error reported")
errorExp.isInverted = true
+
+ let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed")
let compilationTimeExp = expectation(description: "Compilation Time reported")
let errorHandler = EventMapping { event, _, params, _ in
if case .contentBlockingCompilationFailed(let listName, let component) = event {
@@ -217,6 +219,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
} else if case .contentBlockingCompilationTime = event {
XCTAssertNotNil(params?["compilationTime"])
compilationTimeExp.fulfill()
+ } else if case .contentBlockingLRCMissing = event {
+ lookupAndFetchExp.fulfill()
} else {
XCTFail("Unexpected event: \(event)")
}
@@ -227,7 +231,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
updateListener: rulesUpdateListener,
errorReporting: errorHandler)
- wait(for: [exp, errorExp, compilationTimeExp], timeout: 15.0)
+ wait(for: [exp, errorExp, compilationTimeExp, lookupAndFetchExp], timeout: 15.0)
XCTAssertNotNil(cbrm.currentRules)
XCTAssertEqual(cbrm.currentRules.first?.etag, mockRulesSource.trackerData?.etag)
@@ -254,6 +258,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
}
let errorExp = expectation(description: "Error reported")
+ let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed")
+
let errorHandler = EventMapping { event, _, params, _ in
if case .contentBlockingCompilationFailed(let listName, let component) = event {
XCTAssertEqual(listName, DefaultContentBlockerRulesListsSource.Constants.trackerDataSetRulesListName)
@@ -266,6 +272,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
} else if case .contentBlockingCompilationTime = event {
XCTAssertNotNil(params?["compilationTime"])
+ } else if case .contentBlockingLRCMissing = event {
+ lookupAndFetchExp.fulfill()
} else {
XCTFail("Unexpected event: \(event)")
}
@@ -276,7 +284,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
updateListener: rulesUpdateListener,
errorReporting: errorHandler)
- wait(for: [exp, errorExp], timeout: 15.0)
+ wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0)
XCTAssertNotNil(cbrm.currentRules)
XCTAssertEqual(cbrm.currentRules.first?.etag, mockRulesSource.embeddedTrackerData.etag)
@@ -539,6 +547,9 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
let errorExp = expectation(description: "Error reported")
errorExp.expectedFulfillmentCount = 5
+
+ let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed")
+
var errorEvents = [ContentBlockerDebugEvents.Component]()
let errorHandler = EventMapping { event, _, params, _ in
if case .contentBlockingCompilationFailed(let listName, let component) = event {
@@ -554,6 +565,8 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
} else if case .contentBlockingCompilationTime = event {
XCTAssertNotNil(params?["compilationTime"])
errorExp.fulfill()
+ } else if case .contentBlockingLRCMissing = event {
+ lookupAndFetchExp.fulfill()
} else {
XCTFail("Unexpected event: \(event)")
}
@@ -564,7 +577,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
updateListener: rulesUpdateListener,
errorReporting: errorHandler)
- wait(for: [exp, errorExp], timeout: 15.0)
+ wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0)
XCTAssertEqual(Set(errorEvents), Set([.tds,
.tempUnprotected,
@@ -619,6 +632,9 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
let errorExp = expectation(description: "Error reported")
errorExp.expectedFulfillmentCount = 4
+
+ let lookupAndFetchExp = expectation(description: "Look and Fetch rules failed")
+
var errorEvents = [ContentBlockerDebugEvents.Component]()
let errorHandler = EventMapping { event, _, params, _ in
if case .contentBlockingCompilationFailed(let listName, let component) = event {
@@ -634,7 +650,10 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
} else if case .contentBlockingCompilationTime = event {
XCTAssertNotNil(params?["compilationTime"])
errorExp.fulfill()
- } else {
+ } else if case .contentBlockingLRCMissing = event {
+ lookupAndFetchExp.fulfill()
+ } else
+ {
XCTFail("Unexpected event: \(event)")
}
}
@@ -644,7 +663,7 @@ class ContentBlockerRulesManagerLoadingTests: ContentBlockerRulesManagerTests {
updateListener: rulesUpdateListener,
errorReporting: errorHandler)
- wait(for: [exp, errorExp], timeout: 15.0)
+ wait(for: [exp, errorExp, lookupAndFetchExp], timeout: 15.0)
XCTAssertEqual(Set(errorEvents), Set([.tempUnprotected,
.allowlist,
diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift b/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift
index e1cb475cd..767250149 100644
--- a/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift
+++ b/Tests/BrowserServicesKitTests/ContentBlocker/UserContentControllerTests.swift
@@ -318,18 +318,26 @@ class PrivacyConfigurationMock: PrivacyConfiguration {
return .enabled
}
- func isSubfeatureEnabled(
- _ subfeature: any PrivacySubfeature,
- versionProvider: AppVersionProvider,
- randomizer: (Range) -> Double
- ) -> Bool {
+ func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool {
true
}
- func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
return .enabled
}
+ func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
+ return .enabled
+ }
+
+ func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
+ func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
var identifier: String = "abcd"
var version: String? = "123456789"
var userUnprotectedDomains: [String] = []
diff --git a/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift b/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift
index 31370ce4a..bc979233e 100644
--- a/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift
+++ b/Tests/BrowserServicesKitTests/ContentBlocker/WebViewTestHelper.swift
@@ -225,3 +225,11 @@ final class WebKitTestHelper {
}
}
}
+
+class MockExperimentCohortsManager: ExperimentCohortsManaging {
+ func resolveCohort(for experiment: ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? {
+ return nil
+ }
+
+ var experiments: Experiments?
+}
diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift
index 3eb2db17e..6f5fffbd4 100644
--- a/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/DefaultFeatureFlaggerTests.swift
@@ -52,15 +52,18 @@ final class CapturingFeatureFlagOverriding: FeatureFlagLocalOverriding {
final class DefaultFeatureFlaggerTests: XCTestCase {
var internalUserDeciderStore: MockInternalUserStoring!
+ var experimentManager: MockExperimentManager!
var overrides: CapturingFeatureFlagOverriding!
override func setUp() {
super.setUp()
internalUserDeciderStore = MockInternalUserStoring()
+ experimentManager = MockExperimentManager()
}
override func tearDown() {
internalUserDeciderStore = nil
+ experimentManager = nil
super.tearDown()
}
@@ -72,9 +75,9 @@ final class DefaultFeatureFlaggerTests: XCTestCase {
func testWhenInternalOnly_returnsIsInternalUserValue() {
let featureFlagger = createFeatureFlagger()
internalUserDeciderStore.isInternalUser = false
- XCTAssertFalse(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly))
+ XCTAssertFalse(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly()))
internalUserDeciderStore.isInternalUser = true
- XCTAssertTrue(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly))
+ XCTAssertTrue(featureFlagger.isFeatureOn(for: FeatureFlagSource.internalOnly()))
}
func testWhenRemoteDevelopment_isNOTInternalUser_returnsFalse() {
@@ -141,6 +144,128 @@ final class DefaultFeatureFlaggerTests: XCTestCase {
assertFeatureFlagger(with: embeddedData, willReturn: false, for: sourceProvider)
}
+ // MARK: - Experiments
+
+ func testWhenGetCohortIfEnabled_andSourceDisabled_returnsNil() {
+ let featureFlagger = createFeatureFlagger()
+ let flag = FakeExperimentFlag(source: .disabled)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func estWhenGetCohortIfEnabled_andSourceInternal_returnsPassedCohort() {
+ let featureFlagger = createFeatureFlagger()
+ let flag = FakeExperimentFlag(source: .internalOnly(AutofillCohort.blue))
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertEqual(cohort?.rawValue, AutofillCohort.blue.rawValue)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssigned_returnsAssignedCohort() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = AutofillCohort.control.rawValue
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertEqual(cohort?.rawValue, AutofillCohort.control.rawValue)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateFalse_and_cohortAssigned_returnsNil() {
+ internalUserDeciderStore.isInternalUser = false
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = AutofillCohort.control.rawValue
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssigned_andFeaturePassed_returnsNil() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = AutofillCohort.control.rawValue
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteDevelopment(.feature(.autofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortNotAssigned_returnsNil() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = nil
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteInternal_andInternalStateTrue_and_cohortAssignedButNorMatchingEnum_returnsNil() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = "some"
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteDevelopment(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssigned_returnsAssignedCohort() {
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = AutofillCohort.control.rawValue
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertEqual(cohort?.rawValue, AutofillCohort.control.rawValue)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssigned_andFeaturePassed_returnsNil() {
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = AutofillCohort.control.rawValue
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteReleasable(.feature(.autofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortNotAssigned_andFeaturePassed_returnsNil() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = nil
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
+ func testWhenGetCohortIfEnabled_andRemoteReleasable_and_cohortAssignedButNotMatchingEnum_returnsNil() {
+ internalUserDeciderStore.isInternalUser = true
+ let subfeature = AutofillSubfeature.credentialsAutofill
+ experimentManager.cohortToReturn = "some"
+ let embeddedData = Self.embeddedConfig(autofillSubfeatureForState: (subfeature: subfeature, state: "enabled"))
+
+ let flag = FakeExperimentFlag(source: .remoteReleasable(.subfeature(AutofillSubfeature.credentialsAutofill)))
+ let featureFlagger = createFeatureFlagger(withMockedConfigData: embeddedData)
+ let cohort = featureFlagger.getCohortIfEnabled(for: flag)
+ XCTAssertNil(cohort)
+ }
+
// MARK: - Overrides
func testWhenFeatureFlaggerIsInitializedWithLocalOverridesAndUserIsNotInternalThenAllFlagsAreCleared() throws {
@@ -186,7 +311,7 @@ final class DefaultFeatureFlaggerTests: XCTestCase {
localProtection: MockDomainsProtectionStore(),
internalUserDecider: DefaultInternalUserDecider())
let internalUserDecider = DefaultInternalUserDecider(store: internalUserDeciderStore)
- return DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: manager)
+ return DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: manager, experimentManager: experimentManager)
}
private func createFeatureFlaggerWithLocalOverrides(withMockedConfigData data: Data = DefaultFeatureFlaggerTests.embeddedConfig()) -> DefaultFeatureFlagger {
@@ -203,6 +328,7 @@ final class DefaultFeatureFlaggerTests: XCTestCase {
internalUserDecider: internalUserDecider,
privacyConfigManager: manager,
localOverrides: overrides,
+ experimentManager: nil,
for: TestFeatureFlag.self
)
}
@@ -243,3 +369,25 @@ extension FeatureFlagSource: FeatureFlagDescribing {
public var rawValue: String { "rawValue" }
public var source: FeatureFlagSource { self }
}
+
+class MockExperimentManager: ExperimentCohortsManaging {
+ var cohortToReturn: CohortID?
+ var experiments: BrowserServicesKit.Experiments?
+
+ func resolveCohort(for experiment: BrowserServicesKit.ExperimentSubfeature, allowCohortReassignment: Bool) -> CohortID? {
+ return cohortToReturn
+ }
+}
+
+private struct FakeExperimentFlag: FeatureFlagExperimentDescribing {
+ typealias CohortType = AutofillCohort
+
+ var rawValue: String = "fake-experiment"
+
+ var source: FeatureFlagSource
+}
+
+private enum AutofillCohort: String, FlagCohort {
+ case control
+ case blue
+}
diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift
new file mode 100644
index 000000000..48fc85355
--- /dev/null
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentCohortsManagerTests.swift
@@ -0,0 +1,189 @@
+//
+// ExperimentCohortsManagerTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import XCTest
+@testable import BrowserServicesKit
+
+final class ExperimentCohortsManagerTests: XCTestCase {
+
+ let cohort1 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort1", "weight": 1])!
+ let cohort2 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort2", "weight": 0])!
+ let cohort3 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort3", "weight": 2])!
+ let cohort4 = PrivacyConfigurationData.Cohort(json: ["name": "Cohort4", "weight": 0])!
+
+ var mockStore: MockExperimentDataStore!
+ var experimentCohortsManager: ExperimentCohortsManager!
+
+ let subfeatureName1 = "TestSubfeature1"
+ var experimentData1: ExperimentData!
+
+ let subfeatureName2 = "TestSubfeature2"
+ var experimentData2: ExperimentData!
+
+ let subfeatureName3 = "TestSubfeature3"
+ var experimentData3: ExperimentData!
+
+ let subfeatureName4 = "TestSubfeature4"
+ var experimentData4: ExperimentData!
+
+ let encoder: JSONEncoder = {
+ let encoder = JSONEncoder()
+ encoder.dateEncodingStrategy = .secondsSince1970
+ return encoder
+ }()
+
+ override func setUp() {
+ super.setUp()
+ mockStore = MockExperimentDataStore()
+ experimentCohortsManager = ExperimentCohortsManager(
+ store: mockStore
+ )
+
+ let expectedDate1 = Date()
+ experimentData1 = ExperimentData(parentID: "TestParent", cohortID: cohort1.name, enrollmentDate: expectedDate1)
+
+ let expectedDate2 = Date().addingTimeInterval(60)
+ experimentData2 = ExperimentData(parentID: "TestParent", cohortID: cohort2.name, enrollmentDate: expectedDate2)
+
+ let expectedDate3 = Date()
+ experimentData3 = ExperimentData(parentID: "TestParent", cohortID: cohort3.name, enrollmentDate: expectedDate3)
+
+ let expectedDate4 = Date().addingTimeInterval(60)
+ experimentData4 = ExperimentData(parentID: "TestParent", cohortID: cohort4.name, enrollmentDate: expectedDate4)
+ }
+
+ override func tearDown() {
+ mockStore = nil
+ experimentCohortsManager = nil
+ experimentData1 = nil
+ experimentData2 = nil
+ super.tearDown()
+ }
+
+ func testExperimentReturnAssignedExperiments() {
+ // GIVEN
+ mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
+
+ // WHEN
+ let experiments = experimentCohortsManager.experiments
+
+ // THEN
+ XCTAssertEqual(experiments?.count, 2)
+ XCTAssertEqual(experiments?[subfeatureName1], experimentData1)
+ XCTAssertEqual(experiments?[subfeatureName2], experimentData2)
+ XCTAssertNil(experiments?[subfeatureName3])
+ }
+
+ func testCohortReturnsCohortIDIfExistsForMultipleSubfeatures() {
+ // GIVEN
+ mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
+
+ // WHEN
+ let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort1, cohort2]), allowCohortReassignment: false)
+ let result2 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData2.parentID, subfeatureID: subfeatureName2, cohorts: [cohort2, cohort3]), allowCohortReassignment: false)
+
+ // THEN
+ XCTAssertEqual(result1, experimentData1.cohortID)
+ XCTAssertEqual(result2, experimentData2.cohortID)
+ }
+
+ func testCohortAssignIfEnabledWhenNoCohortExists() {
+ // GIVEN
+ mockStore.experiments = [:]
+ let cohorts = [cohort1, cohort2]
+ let experiment = ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: cohorts)
+
+ // WHEN
+ let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true)
+
+ // THEN
+ XCTAssertNotNil(result)
+ XCTAssertEqual(result, experimentData1.cohortID)
+ }
+
+ func testCohortDoesNotAssignIfAssignIfEnabledIsFalse() {
+ // GIVEN
+ mockStore.experiments = [:]
+ let cohorts = [cohort1, cohort2]
+ let experiment = ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: cohorts)
+
+ // WHEN
+ let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: false)
+
+ // THEN
+ XCTAssertNil(result)
+ }
+
+ func testCohortDoesNotAssignIfAssignIfEnabledIsTrueButNoCohortsAvailable() {
+ // GIVEN
+ mockStore.experiments = [:]
+ let experiment = ExperimentSubfeature(parentID: "TestParent", subfeatureID: "NonExistentSubfeature", cohorts: [])
+
+ // WHEN
+ let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true)
+
+ // THEN
+ XCTAssertNil(result)
+ }
+
+ func testCohortReassignsCohortIfAssignedCohortDoesNotExistAndAssignIfEnabledIsTrue() {
+ // GIVEN
+ mockStore.experiments = [subfeatureName1: experimentData1]
+
+ // WHEN
+ let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort2, cohort3]), allowCohortReassignment: true)
+
+ // THEN
+ XCTAssertEqual(result1, experimentData3.cohortID)
+ }
+
+ func testCohortDoesNotReassignsCohortIfAssignedCohortDoesNotExistAndAssignIfEnabledIsTrue() {
+ // GIVEN
+ mockStore.experiments = [subfeatureName1: experimentData1]
+
+ // WHEN
+ let result1 = experimentCohortsManager.resolveCohort(for: ExperimentSubfeature(parentID: experimentData1.parentID, subfeatureID: subfeatureName1, cohorts: [cohort2, cohort3]), allowCohortReassignment: false)
+
+ // THEN
+ XCTAssertNil(result1)
+ }
+
+ func testCohortAssignsBasedOnWeight() {
+ // GIVEN
+ let experiment = ExperimentSubfeature(parentID: experimentData3.parentID, subfeatureID: subfeatureName3, cohorts: [cohort3, cohort4])
+
+ let randomizer: (Range) -> Double = { range in
+ return 1.5
+ }
+
+ experimentCohortsManager = ExperimentCohortsManager(
+ store: mockStore,
+ randomizer: randomizer
+ )
+
+ // WHEN
+ let result = experimentCohortsManager.resolveCohort(for: experiment, allowCohortReassignment: true)
+
+ // THEN
+ XCTAssertEqual(result, experimentData3.cohortID)
+ }
+}
+
+class MockExperimentDataStore: ExperimentsDataStoring {
+ var experiments: Experiments?
+}
diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift
similarity index 80%
rename from Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift
rename to Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift
index 0466155b0..77da51663 100644
--- a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentsDataStoreTests.swift
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/ExperimentsDataStoreTests.swift
@@ -47,8 +47,8 @@ final class ExperimentsDataStoreTests: XCTestCase {
func testExperimentsGetReturnsDecodedExperiments() {
// GIVEN
- let experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: Date())
- let experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: Date())
+ let experimentData1 = ExperimentData(parentID: "parent", cohortID: "TestCohort1", enrollmentDate: Date())
+ let experimentData2 = ExperimentData(parentID: "parent", cohortID: "TestCohort2", enrollmentDate: Date())
let experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
let encoder = JSONEncoder()
@@ -62,17 +62,17 @@ final class ExperimentsDataStoreTests: XCTestCase {
// THEN
let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(result?[subfeatureName1]?.enrollmentDate ?? Date()))
let timeDifference2 = abs(experimentData2.enrollmentDate.timeIntervalSince(result?[subfeatureName2]?.enrollmentDate ?? Date()))
- XCTAssertEqual(result?[subfeatureName1]?.cohort, experimentData1.cohort)
+ XCTAssertEqual(result?[subfeatureName1]?.cohortID, experimentData1.cohortID)
XCTAssertLessThanOrEqual(timeDifference1, 1.0)
- XCTAssertEqual(result?[subfeatureName2]?.cohort, experimentData2.cohort)
+ XCTAssertEqual(result?[subfeatureName2]?.cohortID, experimentData2.cohortID)
XCTAssertLessThanOrEqual(timeDifference2, 1.0)
}
func testExperimentsSetEncodesAndStoresData() throws {
// GIVEN
- let experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: Date())
- let experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: Date())
+ let experimentData1 = ExperimentData(parentID: "parent", cohortID: "TestCohort1", enrollmentDate: Date())
+ let experimentData2 = ExperimentData(parentID: "parent2", cohortID: "TestCohort2", enrollmentDate: Date())
let experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
// WHEN
@@ -85,9 +85,11 @@ final class ExperimentsDataStoreTests: XCTestCase {
let decodedExperiments = try? decoder.decode(Experiments.self, from: storedData)
let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(decodedExperiments?[subfeatureName1]?.enrollmentDate ?? Date()))
let timeDifference2 = abs(experimentData2.enrollmentDate.timeIntervalSince(decodedExperiments?[subfeatureName2]?.enrollmentDate ?? Date()))
- XCTAssertEqual(decodedExperiments?[subfeatureName1]?.cohort, experimentData1.cohort)
+ XCTAssertEqual(decodedExperiments?[subfeatureName1]?.cohortID, experimentData1.cohortID)
+ XCTAssertEqual(decodedExperiments?[subfeatureName1]?.parentID, experimentData1.parentID)
XCTAssertLessThanOrEqual(timeDifference1, 1.0)
- XCTAssertEqual(decodedExperiments?[subfeatureName2]?.cohort, experimentData2.cohort)
+ XCTAssertEqual(decodedExperiments?[subfeatureName2]?.cohortID, experimentData2.cohortID)
+ XCTAssertEqual(decodedExperiments?[subfeatureName2]?.parentID, experimentData2.parentID)
XCTAssertLessThanOrEqual(timeDifference2, 1.0)
}
}
diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift
index 5e2407ca1..5223ff059 100644
--- a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlagLocalOverridesTests.swift
@@ -46,7 +46,7 @@ final class FeatureFlagLocalOverridesTests: XCTestCase {
let internalUserDecider = DefaultInternalUserDecider(store: internalUserDeciderStore)
let privacyConfig = MockPrivacyConfiguration()
let privacyConfigManager = MockPrivacyConfigurationManager(privacyConfig: privacyConfig, internalUserDecider: internalUserDecider)
- featureFlagger = DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: privacyConfigManager)
+ featureFlagger = DefaultFeatureFlagger(internalUserDecider: internalUserDecider, privacyConfigManager: privacyConfigManager, experimentManager: nil)
keyValueStore = MockKeyValueStore()
actionHandler = CapturingFeatureFlagLocalOverridesHandler()
diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift
new file mode 100644
index 000000000..b6931ef22
--- /dev/null
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/FeatureFlaggerExperimentsTests.swift
@@ -0,0 +1,1228 @@
+//
+// FeatureFlaggerExperimentsTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import XCTest
+@testable import BrowserServicesKit
+
+struct CredentialsSavingFlag: FeatureFlagExperimentDescribing {
+
+ typealias CohortType = Cohort
+
+ var rawValue = "credentialSaving"
+
+ var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.credentialsSaving))
+
+ enum Cohort: String, FlagCohort {
+ case control
+ case blue
+ case red
+ }
+}
+
+struct InlineIconCredentialsFlag: FeatureFlagExperimentDescribing {
+
+ typealias CohortType = Cohort
+
+ var rawValue = "inlineIconCredentials"
+
+ var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.inlineIconCredentials))
+
+ enum Cohort: String, FlagCohort {
+ case control
+ case blue
+ case green
+ }
+}
+
+struct AccessCredentialManagementFlag: FeatureFlagExperimentDescribing {
+
+ typealias CohortType = Cohort
+
+ var rawValue = "accessCredentialManagement"
+
+ var source: FeatureFlagSource = .remoteReleasable(.subfeature(AutofillSubfeature.accessCredentialManagement))
+
+ enum Cohort: String, FlagCohort {
+ case control
+ case blue
+ case green
+ }
+}
+
+final class FeatureFlaggerExperimentsTests: XCTestCase {
+
+ var featureJson: Data = "{}".data(using: .utf8)!
+ var mockEmbeddedData: MockEmbeddedDataProvider!
+ var mockStore: MockExperimentDataStore!
+ var experimentManager: ExperimentCohortsManager!
+ var manager: PrivacyConfigurationManager!
+ var locale: Locale!
+ var featureFlagger: FeatureFlagger!
+
+ let subfeatureName = "credentialsSaving"
+
+ override func setUp() {
+ locale = Locale(identifier: "fr_US")
+ mockEmbeddedData = MockEmbeddedDataProvider(data: featureJson, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ mockStore = MockExperimentDataStore()
+ experimentManager = ExperimentCohortsManager(store: mockStore)
+ manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ featureFlagger = DefaultFeatureFlagger(internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore), privacyConfigManager: manager, experimentManager: experimentManager)
+ }
+
+ override func tearDown() {
+ featureJson = "".data(using: .utf8)!
+ mockEmbeddedData = nil
+ mockStore = nil
+ experimentManager = nil
+ manager = nil
+ }
+
+ func testCohortOnlyAssignedWhenCallingStateForSubfeature() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+ let config = manager.privacyConfig
+
+ // we haven't called getCohortIfEnabled yet, so cohorts should not be yet assigned
+ XCTAssertNil(mockStore.experiments)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ // we call isSubfeatureEnabled() hould not be assigned either
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled)
+ XCTAssertNil(mockStore.experiments)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ // we call getCohortIfEnabled(cohort), then we should assign cohort
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+ }
+
+ func testRemoveAllCohortsRemotelyRemovesAssignedCohort() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+ var config = manager.privacyConfig
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // remove blue cohort
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+ config = manager.privacyConfig
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // remove all remaining cohorts
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+ config = manager.privacyConfig
+
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertTrue(mockStore.experiments?.isEmpty ?? false)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ }
+
+ func testRemoveAssignedCohortsRemotelyRemovesAssignedCohortAndTriesToReassign() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "2", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // remove blue cohort
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "red",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "2", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.red.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "red")
+ }
+
+ func testDisablingFeatureDisablesCohort() {
+ // Initially subfeature for both cohorts is disabled
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertNil(mockStore.experiments)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ // When features with cohort the cohort with weight 1 is enabled
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // If the subfeature is then disabled isSubfeatureEnabled should return false
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "disabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // If the subfeature is parent feature disabled isSubfeatureEnabled should return false
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "disabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+ }
+
+ func testCohortsAndTargetsInteraction() {
+ func featureJson(country: String, language: String) -> Data {
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeLanguage": "\(language)",
+ "localeCountry": "\(country)"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ }
+ manager.reload(etag: "", data: featureJson(country: "FR", language: "fr"))
+
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertNil(mockStore.experiments)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ manager.reload(etag: "", data: featureJson(country: "US", language: "en"))
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertNil(mockStore.experiments)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ manager.reload(etag: "", data: featureJson(country: "US", language: "fr"))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // once cohort is assigned, changing targets shall not affect feature state
+ manager.reload(etag: "", data: featureJson(country: "IT", language: "it"))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ let featureJson2 =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "FR"
+ }
+ ],
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson2)
+
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertTrue(mockStore.experiments?.isEmpty ?? false)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+
+ // re-populate experiment to re-assign new cohort, should not be assigned as it has wrong targets
+ manager.reload(etag: "", data: featureJson(country: "IT", language: "it"))
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag()))
+ XCTAssertTrue(mockStore.experiments?.isEmpty ?? false)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName))
+ }
+
+ func testChangeRemoteCohortsAfterAssignmentShouldNoop() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // changing targets should not change cohort assignment
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "IT"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // changing cohort weight should not change current assignment
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 0
+ },
+ {
+ "name": "blue",
+ "weight": 1
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // adding cohorts should not change current assignment
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 1
+ },
+ {
+ "name": "red",
+ "weight": 1
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ }
+
+ func testEnrollmentDate() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "", data: featureJson)
+ let config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertTrue(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertNil(experimentManager.cohort(for: subfeatureName), "control")
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ let currentTime = Date().timeIntervalSince1970
+ let enrollmentTime = mockStore.experiments?[subfeatureName]?.enrollmentDate.timeIntervalSince1970
+
+ XCTAssertNotNil(enrollmentTime)
+ if let enrollmentTime = enrollmentTime {
+ let tolerance: TimeInterval = 60 // 1 minute in seconds
+ XCTAssertEqual(currentTime, enrollmentTime, accuracy: tolerance)
+ }
+ }
+
+ func testRollbackCohortExperiments() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "rollout": {
+ "steps": [
+ {
+ "percent": 100
+ }
+ ]
+ },
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+ var config = manager.privacyConfig
+ clearRolloutData(feature: "autofill", subFeature: "credentialsSaving")
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "rollout": {
+ "steps": [
+ {
+ "percent": 0
+ }
+ ]
+ },
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+ config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ }
+
+ func testCohortEnabledAndStopEnrollmentAndRhenRollBack() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+ var config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // Stop enrollment, should keep assigned cohorts
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 0
+ },
+ {
+ "name": "blue",
+ "weight": 1
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+ config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.control.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "control")
+
+ // remove control, should re-allocate to blue
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "blue",
+ "weight": 1
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+ config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.CohortType.blue.rawValue)
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: subfeatureName), "blue")
+ }
+
+ private func clearRolloutData(feature: String, subFeature: String) {
+ UserDefaults().set(nil, forKey: "config.\(feature).\(subFeature).enabled")
+ UserDefaults().set(nil, forKey: "config.\(feature).\(subFeature).lastRolloutCount")
+ }
+
+ func testAllActiveExperimentsEmptyIfNoAssignedExperiment() {
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+ manager.reload(etag: "foo", data: featureJson)
+
+ let activeExperiments = featureFlagger.getAllActiveExperiments()
+ XCTAssertTrue(activeExperiments.isEmpty)
+ XCTAssertNil(mockStore.experiments)
+ }
+
+ func testAllActiveExperimentsReturnsOnlyActiveExperiments() {
+ var featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "blue",
+ "weight": 0
+ }
+ ]
+ },
+ "inlineIconCredentials": {
+ "state": "enabled",
+ "minSupportedVersion": 1,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 0
+ },
+ {
+ "name": "green",
+ "weight": 1
+ }
+ ]
+ },
+ "accessCredentialManagement": {
+ "state": "enabled",
+ "minSupportedVersion": 3,
+ "targets": [
+ {
+ "localeCountry": "CA"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "green",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ manager.reload(etag: "foo", data: featureJson)
+
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: CredentialsSavingFlag())?.rawValue, CredentialsSavingFlag.Cohort.control.rawValue)
+ XCTAssertEqual(featureFlagger.getCohortIfEnabled(for: InlineIconCredentialsFlag())?.rawValue, InlineIconCredentialsFlag.Cohort.green.rawValue)
+ XCTAssertNil(featureFlagger.getCohortIfEnabled(for: AccessCredentialManagementFlag()))
+
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+ XCTAssertEqual(experimentManager.cohort(for: AutofillSubfeature.credentialsSaving.rawValue), "control")
+ XCTAssertEqual(experimentManager.cohort(for: AutofillSubfeature.inlineIconCredentials.rawValue), "green")
+ XCTAssertNil(experimentManager.cohort(for: AutofillSubfeature.accessCredentialManagement.rawValue))
+ XCTAssertFalse(mockStore.experiments?.isEmpty ?? true)
+
+ var activeExperiments = featureFlagger.getAllActiveExperiments()
+ XCTAssertEqual(activeExperiments.count, 2)
+ XCTAssertEqual(activeExperiments[AutofillSubfeature.credentialsSaving.rawValue]?.cohortID, "control")
+ XCTAssertEqual(activeExperiments[AutofillSubfeature.inlineIconCredentials.rawValue]?.cohortID, "green")
+ XCTAssertNil(activeExperiments[AutofillSubfeature.accessCredentialManagement.rawValue])
+
+ // When an assigned cohort is removed it's not part of active experiments
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "blue",
+ "weight": 1
+ }
+ ]
+ },
+ "inlineIconCredentials": {
+ "state": "enabled",
+ "minSupportedVersion": 1,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 0
+ },
+ {
+ "name": "green",
+ "weight": 1
+ }
+ ]
+ },
+ "accessCredentialManagement": {
+ "state": "enabled",
+ "minSupportedVersion": 3,
+ "targets": [
+ {
+ "localeCountry": "CA"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "green",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ manager.reload(etag: "foo", data: featureJson)
+
+ activeExperiments = featureFlagger.getAllActiveExperiments()
+ XCTAssertEqual(activeExperiments.count, 1)
+ XCTAssertNil(activeExperiments[AutofillSubfeature.credentialsSaving.rawValue])
+ XCTAssertEqual(activeExperiments[AutofillSubfeature.inlineIconCredentials.rawValue]?.cohortID, "green")
+ XCTAssertNil(activeExperiments[AutofillSubfeature.accessCredentialManagement.rawValue])
+
+ // When feature disabled an assigned cohort it's not part of active experiments
+ featureJson =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "minSupportedVersion": 2,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "blue",
+ "weight": 1
+ }
+ ]
+ },
+ "inlineIconCredentials": {
+ "state": "disabled",
+ "minSupportedVersion": 1,
+ "targets": [
+ {
+ "localeCountry": "US"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 0
+ },
+ {
+ "name": "green",
+ "weight": 1
+ }
+ ]
+ },
+ "accessCredentialManagement": {
+ "state": "enabled",
+ "minSupportedVersion": 3,
+ "targets": [
+ {
+ "localeCountry": "CA"
+ }
+ ],
+ "cohorts": [
+ {
+ "name": "control",
+ "weight": 1
+ },
+ {
+ "name": "green",
+ "weight": 0
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ manager.reload(etag: "foo", data: featureJson)
+
+ activeExperiments = featureFlagger.getAllActiveExperiments()
+ XCTAssertTrue(activeExperiments.isEmpty)
+ }
+
+}
diff --git a/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift b/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift
index 8c1a43ee7..8d8ec9929 100644
--- a/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift
+++ b/Tests/BrowserServicesKitTests/FeatureFlagging/TestFeatureFlag.swift
@@ -35,11 +35,11 @@ enum TestFeatureFlag: String, FeatureFlagDescribing {
var source: FeatureFlagSource {
switch self {
case .nonOverridableFlag:
- return .internalOnly
+ return .internalOnly()
case .overridableFlagDisabledByDefault:
return .disabled
case .overridableFlagEnabledByDefault:
- return .internalOnly
+ return .internalOnly()
}
}
}
diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift
index 252a7b866..9388b14e9 100644
--- a/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift
+++ b/Tests/BrowserServicesKitTests/PrivacyConfig/AppPrivacyConfigurationTests.swift
@@ -907,6 +907,186 @@ class AppPrivacyConfigurationTests: XCTestCase {
XCTAssertEqual(configAfterUpdate.stateFor(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:)), .disabled(.disabledInConfig))
}
+ let exampleSubfeatureEnabledWithTarget =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "targets": [
+ {
+ "localeCountry": "US",
+ "localeLanguage": "fr"
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenTargetMatches_SubfeatureShouldBeEnabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "fr_US")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled)
+ }
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenRegionDoesNotMatches_SubfeatureShouldBeDisabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "fr_FR")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.targetDoesNotMatch))
+ }
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureEnabledWhenLanguageDoesNotMatches_SubfeatureShouldBeDisabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "it_US")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.targetDoesNotMatch))
+ }
+
+ let exampleSubfeatureDisabledWithTarget =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "disabled",
+ "targets": [
+ {
+ "localeCountry": "US",
+ "localeLanguage": "fr"
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureDisabledWhenTargetMatches_SubfeatureShouldBeDisabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureDisabledWithTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "fr_US")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.disabledInConfig))
+ }
+
+ let exampleSubfeatureEnabledWithRolloutAndTarget =
+ """
+ {
+ "features": {
+ "autofill": {
+ "state": "enabled",
+ "exceptions": [],
+ "features": {
+ "credentialsSaving": {
+ "state": "enabled",
+ "targets": [
+ {
+ "localeCountry": "US",
+ "localeLanguage": "fr"
+ }
+ ],
+ "rollout": {
+ "steps": [{
+ "percent": 5.0
+ }]
+ }
+ }
+ }
+ }
+ }
+ }
+ """.data(using: .utf8)!
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureEnabledAndTargetMatchesWhenNotInRollout_SubfeatureShouldBeDisabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithRolloutAndTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "fr_US")
+ mockRandomValue = 7.0
+ clearRolloutData(feature: "autofill", subFeature: "credentialsSaving")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertFalse(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:)))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .disabled(.stillInRollout))
+ }
+
+ func testWhenCheckingSubfeatureStateWithSubfeatureEnabledAndTargetMatchesWhenInRollout_SubfeatureShouldBeEnabled() {
+ let mockEmbeddedData = MockEmbeddedDataProvider(data: exampleSubfeatureEnabledWithRolloutAndTarget, etag: "test")
+ let mockInternalUserStore = MockInternalUserStoring()
+ let locale = Locale(identifier: "fr_US")
+ mockRandomValue = 2.0
+ clearRolloutData(feature: "autofill", subFeature: "credentialsSaving")
+
+ let manager = PrivacyConfigurationManager(fetchedETag: nil,
+ fetchedData: nil,
+ embeddedDataProvider: mockEmbeddedData,
+ localProtection: MockDomainsProtectionStore(),
+ internalUserDecider: DefaultInternalUserDecider(store: mockInternalUserStore),
+ locale: locale)
+ let config = manager.privacyConfig
+
+ XCTAssertTrue(config.isSubfeatureEnabled(AutofillSubfeature.credentialsSaving, randomizer: mockRandom(in:)))
+ XCTAssertEqual(config.stateFor(AutofillSubfeature.credentialsSaving), .enabled)
+ }
+
let exampleEnabledSubfeatureWithRollout =
"""
{
diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift
deleted file mode 100644
index 518249560..000000000
--- a/Tests/BrowserServicesKitTests/PrivacyConfig/ExperimentCohortsManagerTests.swift
+++ /dev/null
@@ -1,266 +0,0 @@
-//
-// ExperimentCohortsManagerTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import XCTest
-@testable import BrowserServicesKit
-
-final class ExperimentCohortsManagerTests: XCTestCase {
-
- var mockStore: MockExperimentDataStore!
- var experimentCohortsManager: ExperimentCohortsManager!
-
- let subfeatureName1 = "TestSubfeature1"
- var experimentData1: ExperimentData!
-
- let subfeatureName2 = "TestSubfeature2"
- var experimentData2: ExperimentData!
-
- let encoder: JSONEncoder = {
- let encoder = JSONEncoder()
- encoder.dateEncodingStrategy = .secondsSince1970
- return encoder
- }()
-
- override func setUp() {
- super.setUp()
- mockStore = MockExperimentDataStore()
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 50.0 }
- )
-
- let expectedDate1 = Date()
- experimentData1 = ExperimentData(cohort: "TestCohort1", enrollmentDate: expectedDate1)
-
- let expectedDate2 = Date().addingTimeInterval(60)
- experimentData2 = ExperimentData(cohort: "TestCohort2", enrollmentDate: expectedDate2)
- }
-
- override func tearDown() {
- mockStore = nil
- experimentCohortsManager = nil
- experimentData1 = nil
- experimentData2 = nil
- super.tearDown()
- }
-
- func testCohortReturnsCohortIDIfExistsForMultipleSubfeatures() {
- // GIVEN
- mockStore.experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
-
- // WHEN
- let result1 = experimentCohortsManager.cohort(for: subfeatureName1)
- let result2 = experimentCohortsManager.cohort(for: subfeatureName2)
-
- // THEN
- XCTAssertEqual(result1, experimentData1.cohort)
- XCTAssertEqual(result2, experimentData2.cohort)
- }
-
- func testEnrollmentDateReturnsCorrectDateIfExists() {
- // GIVEN
- mockStore.experiments = [subfeatureName1: experimentData1]
-
- // WHEN
- let result1 = experimentCohortsManager.enrollmentDate(for: subfeatureName1)
- let result2 = experimentCohortsManager.enrollmentDate(for: subfeatureName2)
-
- // THEN
- let timeDifference1 = abs(experimentData1.enrollmentDate.timeIntervalSince(result1 ?? Date()))
-
- XCTAssertLessThanOrEqual(timeDifference1, 1.0, "Expected enrollment date for subfeatureName1 to match at the second level")
- XCTAssertNil(result2)
- }
-
- func testCohortReturnsNilIfCohortDoesNotExist() {
- // GIVEN
- let subfeatureName = "TestSubfeature"
-
- // WHEN
- let result = experimentCohortsManager.cohort(for: subfeatureName)
-
- // THEN
- XCTAssertNil(result)
- }
-
- func testEnrollmentDateReturnsNilIfDateDoesNotExist() {
- // GIVEN
- let subfeatureName = "TestSubfeature"
-
- // WHEN
- let result = experimentCohortsManager.enrollmentDate(for: subfeatureName)
-
- // THEN
- XCTAssertNil(result)
- }
-
- func testRemoveCohortSuccessfullyRemovesData() throws {
- // GIVEN
- mockStore.experiments = [subfeatureName1: experimentData1]
-
- // WHEN
- experimentCohortsManager.removeCohort(from: subfeatureName1)
-
- // THEN
- let experiments = try XCTUnwrap(mockStore.experiments)
- XCTAssertTrue(experiments.isEmpty)
- }
-
- func testRemoveCohortDoesNothingIfSubfeatureDoesNotExist() {
- // GIVEN
- let expectedExperiments: Experiments = [subfeatureName1: experimentData1, subfeatureName2: experimentData2]
- mockStore.experiments = expectedExperiments
-
- // WHEN
- experimentCohortsManager.removeCohort(from: "someOtherSubfeature")
-
- // THEN
- XCTAssertEqual( mockStore.experiments, expectedExperiments)
- }
-
- func testAssignCohortReturnsNilIfNoCohorts() {
- // GIVEN
- let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: [])
-
- // WHEN
- let result = experimentCohortsManager.assignCohort(to: subfeature)
-
- // THEN
- XCTAssertNil(result)
- }
-
- func testAssignCohortReturnsNilIfAllWeightsAreZero() {
- // GIVEN
- let jsonCohort1: [String: Any] = ["name": "TestCohort", "weight": 0]
- let jsonCohort2: [String: Any] = ["name": "TestCohort", "weight": 0]
- let cohorts = [
- PrivacyConfigurationData.Cohort(json: jsonCohort1)!,
- PrivacyConfigurationData.Cohort(json: jsonCohort2)!
- ]
- let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts)
-
- // WHEN
- let result = experimentCohortsManager.assignCohort(to: subfeature)
-
- // THEN
- XCTAssertNil(result)
- }
-
- func testAssignCohortSelectsCorrectCohortBasedOnWeight() {
- // Cohort1 has weight 1, Cohort2 has weight 3
- // Total weight is 1 + 3 = 4
- let jsonCohort1: [String: Any] = ["name": "Cohort1", "weight": 1]
- let jsonCohort2: [String: Any] = ["name": "Cohort2", "weight": 3]
- let cohorts = [
- PrivacyConfigurationData.Cohort(json: jsonCohort1)!,
- PrivacyConfigurationData.Cohort(json: jsonCohort2)!
- ]
- let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts)
- let expectedTotalWeight = 4.0
-
- // Use a custom randomizer to verify the range
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { range in
- // Assert that the range lower bound is 0
- XCTAssertEqual(range.lowerBound, 0.0)
- // Assert that the range upper bound is the total weight
- XCTAssertEqual(range.upperBound, expectedTotalWeight)
- return 0.0
- }
- )
-
- // Test case where random value is at the very start of Cohort1's range (0)
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 0.0 }
- )
- let resultStartOfCohort1 = experimentCohortsManager.assignCohort(to: subfeature)
- XCTAssertEqual(resultStartOfCohort1, "Cohort1")
-
- // Test case where random value is at the end of Cohort1's range (0.9)
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 0.9 }
- )
- let resultEndOfCohort1 = experimentCohortsManager.assignCohort(to: subfeature)
- XCTAssertEqual(resultEndOfCohort1, "Cohort1")
-
- // Test case where random value is at the start of Cohort2's range (1.00 to 4)
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 1.00 }
- )
- let resultStartOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature)
- XCTAssertEqual(resultStartOfCohort2, "Cohort2")
-
- // Test case where random value falls exactly within Cohort2's range (2.5)
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 2.5 }
- )
- let resultMiddleOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature)
- XCTAssertEqual(resultMiddleOfCohort2, "Cohort2")
-
- // Test case where random value is at the end of Cohort2's range (4)
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { _ in 3.9 }
- )
- let resultEndOfCohort2 = experimentCohortsManager.assignCohort(to: subfeature)
- XCTAssertEqual(resultEndOfCohort2, "Cohort2")
- }
-
- func testAssignCohortWithSingleCohortAlwaysSelectsThatCohort() throws {
- // GIVEN
- let jsonCohort1: [String: Any] = ["name": "Cohort1", "weight": 1]
- let cohorts = [
- PrivacyConfigurationData.Cohort(json: jsonCohort1)!
- ]
- let subfeature = ExperimentSubfeature(subfeatureID: subfeatureName1, cohorts: cohorts)
- let expectedTotalWeight = 1.0
-
- // Use a custom randomizer to verify the range
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { range in
- // Assert that the range lower bound is 0
- XCTAssertEqual(range.lowerBound, 0.0)
- // Assert that the range upper bound is the total weight
- XCTAssertEqual(range.upperBound, expectedTotalWeight)
- return 0.0
- }
- )
-
- // WHEN
- experimentCohortsManager = ExperimentCohortsManager(
- store: mockStore,
- randomizer: { range in Double.random(in: range)}
- )
- let result = experimentCohortsManager.assignCohort(to: subfeature)
-
- // THEN
- XCTAssertEqual(result, "Cohort1")
- XCTAssertEqual(cohorts[0].name, mockStore.experiments?[subfeature.subfeatureID]?.cohort)
- }
-
-}
-
-class MockExperimentDataStore: ExperimentsDataStoring {
- var experiments: Experiments?
-}
diff --git a/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift b/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift
index fac814b1c..08ecbec85 100644
--- a/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift
+++ b/Tests/BrowserServicesKitTests/PrivacyConfig/PrivacyConfigurationDataTests.swift
@@ -68,6 +68,9 @@ class PrivacyConfigurationDataTests: XCTestCase {
XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?.count, 3)
XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?[0].name, "myExperimentControl")
XCTAssertEqual(subfeatures["enabledSubfeature"]?.cohorts?[0].weight, 1)
+ XCTAssertEqual(subfeatures["enabledSubfeature"]?.targets?[0].localeCountry, "US")
+ XCTAssertEqual(subfeatures["enabledSubfeature"]?.targets?[0].localeLanguage, "fr")
+ XCTAssertEqual(subfeatures["enabledSubfeature"]?.settings, "{\"foo\":\"foo\\/value\",\"bar\":\"bar\\/value\"}")
XCTAssertEqual(subfeatures["internalSubfeature"]?.state, "internal")
} else {
XCTFail("Could not parse subfeatures")
diff --git a/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json b/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json
index 3fd5be5a5..735c3914f 100644
--- a/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json
+++ b/Tests/BrowserServicesKitTests/Resources/privacy-config-example.json
@@ -171,6 +171,16 @@
},
"enabledSubfeature": {
"state": "enabled",
+ "targets": [
+ {
+ "localeCountry": "US",
+ "localeLanguage": "fr"
+ }
+ ],
+ "settings": {
+ "foo": "foo/value",
+ "bar": "bar/value"
+ },
"description": "A description of the sub-feature",
"cohorts": [
{
diff --git a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests
index 6133e7d9d..a603ff9af 160000
--- a/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests
+++ b/Tests/BrowserServicesKitTests/Resources/privacy-reference-tests
@@ -1 +1 @@
-Subproject commit 6133e7d9d9cd5f1b925cab1971b4d785dc639df7
+Subproject commit a603ff9af22ca3ff7ce2e7ffbfe18c447d9f23e8
diff --git a/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift b/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift
index a85dadc1e..388ad69b9 100644
--- a/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift
+++ b/Tests/BrowserServicesKitTests/Subscription/SubscriptionFeatureAvailabilityTests.swift
@@ -221,17 +221,29 @@ class MockPrivacyConfiguration: PrivacyConfiguration {
var isSubfeatureEnabledCheck: ((any PrivacySubfeature) -> Bool)?
- func isSubfeatureEnabled(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> Bool {
+ func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool {
isSubfeatureEnabledCheck?(subfeature) ?? false
}
- func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
if isSubfeatureEnabledCheck?(subfeature) == true {
return .enabled
}
return .disabled(.disabledInConfig)
}
+ func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
+ return .enabled
+ }
+
+ func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
+ func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
var identifier: String = "abcd"
var version: String? = "123456789"
var userUnprotectedDomains: [String] = []
diff --git a/Tests/CommonTests/Extensions/StringExtensionTests.swift b/Tests/CommonTests/Extensions/StringExtensionTests.swift
index bcf415895..65abb79c8 100644
--- a/Tests/CommonTests/Extensions/StringExtensionTests.swift
+++ b/Tests/CommonTests/Extensions/StringExtensionTests.swift
@@ -16,8 +16,10 @@
// limitations under the License.
//
+import CryptoKit
import Foundation
import XCTest
+
@testable import Common
final class StringExtensionTests: XCTestCase {
@@ -370,4 +372,13 @@ final class StringExtensionTests: XCTestCase {
}
}
+ func testSha256() {
+ let string = "Hello, World! This is a test string."
+ let hash = string.sha256
+ let expected = "3c2b805ab0038afb0629e1d598ae73e0caabb69de03e96762977d34e8ba428bf"
+ let expectedSHA256 = SHA256.hash(data: Data(string.utf8)).map { String(format: "%02hhx", $0) }.joined()
+ XCTAssertEqual(hash, expected)
+ XCTAssertEqual(hash, expectedSHA256)
+ }
+
}
diff --git a/Tests/DDGSyncTests/Mocks/Mocks.swift b/Tests/DDGSyncTests/Mocks/Mocks.swift
index 013123c6d..4745f155c 100644
--- a/Tests/DDGSyncTests/Mocks/Mocks.swift
+++ b/Tests/DDGSyncTests/Mocks/Mocks.swift
@@ -169,18 +169,26 @@ class MockPrivacyConfiguration: PrivacyConfiguration {
return .enabled
}
- func isSubfeatureEnabled(
- _ subfeature: any PrivacySubfeature,
- versionProvider: AppVersionProvider,
- randomizer: (Range) -> Double
- ) -> Bool {
+ func isSubfeatureEnabled(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> Bool {
true
}
- func stateFor(_ subfeature: any PrivacySubfeature, versionProvider: AppVersionProvider, randomizer: (Range) -> Double) -> PrivacyConfigurationFeatureState {
+ func stateFor(_ subfeature: any BrowserServicesKit.PrivacySubfeature, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
return .enabled
}
+ func stateFor(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID, versionProvider: BrowserServicesKit.AppVersionProvider, randomizer: (Range) -> Double) -> BrowserServicesKit.PrivacyConfigurationFeatureState {
+ return .enabled
+ }
+
+ func cohorts(for subfeature: any BrowserServicesKit.PrivacySubfeature) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
+ func cohorts(subfeatureID: BrowserServicesKit.SubfeatureID, parentFeatureID: BrowserServicesKit.ParentFeatureID) -> [BrowserServicesKit.PrivacyConfigurationData.Cohort]? {
+ return nil
+ }
+
var identifier: String = "abcd"
var version: String? = "123456789"
var userUnprotectedDomains: [String] = []
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift
new file mode 100644
index 000000000..f6b0de23a
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift
@@ -0,0 +1,103 @@
+//
+// MaliciousSiteDetectorTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+import Foundation
+import XCTest
+
+@testable import MaliciousSiteProtection
+
+class MaliciousSiteDetectorTests: XCTestCase {
+
+ private var mockAPIClient: MockMaliciousSiteProtectionAPIClient!
+ private var mockDataManager: MockMaliciousSiteProtectionDataManager!
+ private var mockEventMapping: MockEventMapping!
+ private var detector: MaliciousSiteDetector!
+
+ override func setUp() async throws {
+ mockAPIClient = MockMaliciousSiteProtectionAPIClient()
+ mockDataManager = MockMaliciousSiteProtectionDataManager()
+ mockEventMapping = MockEventMapping()
+ detector = MaliciousSiteDetector(apiClient: mockAPIClient, dataManager: mockDataManager, eventMapping: mockEventMapping)
+ }
+
+ override func tearDown() async throws {
+ mockAPIClient = nil
+ mockDataManager = nil
+ mockEventMapping = nil
+ detector = nil
+ }
+
+ func testIsMaliciousWithLocalFilterHit() async {
+ let filter = Filter(hash: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*")
+ await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing))
+ await mockDataManager.store(HashPrefixSet(revision: 0, items: ["255a8a79"]), for: .hashPrefixes(threatKind: .phishing))
+
+ let url = URL(string: "https://malicious.com/")!
+
+ let result = await detector.evaluate(url)
+
+ XCTAssertEqual(result, .phishing)
+ }
+
+ func testIsMaliciousWithApiMatch() async {
+ await mockDataManager.store(FilterDictionary(revision: 0, items: []), for: .filterSet(threatKind: .phishing))
+ await mockDataManager.store(HashPrefixSet(revision: 0, items: ["a379a6f6"]), for: .hashPrefixes(threatKind: .phishing))
+
+ let url = URL(string: "https://example.com/mal")!
+
+ let result = await detector.evaluate(url)
+
+ XCTAssertEqual(result, .phishing)
+ }
+
+ func testIsMaliciousWithHashPrefixMatch() async {
+ let filter = Filter(hash: "notamatch", regex: ".*malicious.*")
+ await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing))
+ await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24" /* matches safe.com */]), for: .hashPrefixes(threatKind: .phishing))
+
+ let url = URL(string: "https://safe.com")!
+
+ let result = await detector.evaluate(url)
+
+ XCTAssertNil(result)
+ }
+
+ func testIsMaliciousWithFullHashMatch() async {
+ // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b
+ let filter = Filter(hash: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI")
+ await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing))
+ await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24"]), for: .hashPrefixes(threatKind: .phishing))
+
+ let url = URL(string: "https://safe.com")!
+
+ let result = await detector.evaluate(url)
+
+ XCTAssertNil(result)
+ }
+
+ func testIsMaliciousWithNoHashPrefixMatch() async {
+ let filter = Filter(hash: "testHash", regex: ".*malicious.*")
+ await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing))
+ await mockDataManager.store(HashPrefixSet(revision: 0, items: ["testPrefix"]), for: .hashPrefixes(threatKind: .phishing))
+
+ let url = URL(string: "https://safe.com")!
+
+ let result = await detector.evaluate(url)
+
+ XCTAssertNil(result)
+ }
+}
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift
new file mode 100644
index 000000000..fcea80939
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift
@@ -0,0 +1,143 @@
+//
+// MaliciousSiteProtectionAPIClientTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+import Foundation
+import Networking
+import TestUtils
+import XCTest
+
+@testable import MaliciousSiteProtection
+
+final class MaliciousSiteProtectionAPIClientTests: XCTestCase {
+
+ var mockService: MockAPIService!
+ var client: MaliciousSiteProtection.APIClient!
+
+ override func setUp() {
+ super.setUp()
+ mockService = MockAPIService()
+ client = .init(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService)
+ }
+
+ override func tearDown() {
+ mockService = nil
+ client = nil
+ super.tearDown()
+ }
+
+ func testWhenPhishingFilterSetRequestedAndSucceeds_ChangeSetIsReturned() async throws {
+ // Given
+ let insertFilter = Filter(hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".")
+ let deleteFilter = Filter(hash: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?")
+ let expectedResponse = APIClient.Response.FiltersChangeSet(insert: [insertFilter], delete: [deleteFilter], revision: 666, replace: false)
+ mockService.requestHandler = { [unowned self] in
+ XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .filterSet(.init(threatKind: .phishing, revision: 666))))
+ let data = try? JSONEncoder().encode(expectedResponse)
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: data, httpResponse: response))
+ }
+
+ // When
+ let response = try await client.filtersChangeSet(for: .phishing, revision: 666)
+
+ // Then
+ XCTAssertEqual(response, expectedResponse)
+ }
+
+ func testWhenHashPrefixesRequestedAndSucceeds_ChangeSetIsReturned() async throws {
+ // Given
+ let expectedResponse = APIClient.Response.HashPrefixesChangeSet(insert: ["abc"], delete: ["def"], revision: 1, replace: false)
+ mockService.requestHandler = { [unowned self] in
+ XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1))))
+ let data = try? JSONEncoder().encode(expectedResponse)
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: data, httpResponse: response))
+ }
+
+ // When
+ let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: 1)
+
+ // Then
+ XCTAssertEqual(response, expectedResponse)
+ }
+
+ func testWhenMatchesRequestedAndSucceeds_MatchesAreReturned() async throws {
+ // Given
+ let expectedResponse = APIClient.Response.Matches(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)])
+ mockService.requestHandler = { [unowned self] in
+ XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .matches(.init(hashPrefix: "abc"))))
+ let data = try? JSONEncoder().encode(expectedResponse)
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: data, httpResponse: response))
+ }
+
+ // When
+ let response = try await client.matches(forHashPrefix: "abc")
+
+ // Then
+ XCTAssertEqual(response.matches, expectedResponse.matches)
+ }
+
+ func testWhenHashPrefixesRequestFails_ErrorThrown() async throws {
+ // Given
+ let invalidRevision = -1
+ mockService.requestHandler = {
+ // Simulate a failure or invalid request
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: nil, httpResponse: response))
+ }
+
+ do {
+ let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision)
+ XCTFail("Unexpected \(response) expected throw")
+ } catch {
+ }
+ }
+
+ func testWhenFilterSetRequestFails_ErrorThrown() async throws {
+ // Given
+ let invalidRevision = -1
+ mockService.requestHandler = {
+ // Simulate a failure or invalid request
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: nil, httpResponse: response))
+ }
+
+ do {
+ let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision)
+ XCTFail("Unexpected \(response) expected throw")
+ } catch {
+ }
+ }
+
+ func testWhenMatchesRequestFails_ErrorThrown() async throws {
+ // Given
+ let invalidHashPrefix = ""
+ mockService.requestHandler = {
+ // Simulate a failure or invalid request
+ let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)!
+ return .success(.init(data: nil, httpResponse: response))
+ }
+
+ do {
+ let response = try await client.matches(forHashPrefix: invalidHashPrefix)
+ XCTFail("Unexpected \(response) expected throw")
+ } catch {
+ }
+ }
+
+}
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift
new file mode 100644
index 000000000..5164f78d3
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift
@@ -0,0 +1,250 @@
+//
+// MaliciousSiteProtectionDataManagerTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+import XCTest
+
+@testable import MaliciousSiteProtection
+
+class MaliciousSiteProtectionDataManagerTests: XCTestCase {
+ var embeddedDataProvider: MockMaliciousSiteProtectionEmbeddedDataProvider!
+ enum Constants {
+ static let hashPrefixesFileName = "phishingHashPrefixes.json"
+ static let filterSetFileName = "phishingFilterSet.json"
+ }
+ let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName]
+ var dataManager: MaliciousSiteProtection.DataManager!
+ var fileStore: MaliciousSiteProtection.FileStoring!
+
+ override func setUp() async throws {
+ embeddedDataProvider = MockMaliciousSiteProtectionEmbeddedDataProvider()
+ fileStore = MockMaliciousSiteProtectionFileStore()
+ setUpDataManager()
+ }
+
+ func setUpDataManager() {
+ dataManager = MaliciousSiteProtection.DataManager(fileStore: fileStore, embeddedDataProvider: embeddedDataProvider, fileNameProvider: { dataType in
+ switch dataType {
+ case .filterSet: Constants.filterSetFileName
+ case .hashPrefixSet: Constants.hashPrefixesFileName
+ }
+ })
+ }
+
+ override func tearDown() async throws {
+ embeddedDataProvider = nil
+ dataManager = nil
+ }
+
+ func clearDatasets() {
+ for fileName in datasetFiles {
+ let emptyData = Data()
+ fileStore.write(data: emptyData, to: fileName)
+ }
+ }
+
+ func testWhenNoDataSavedThenProviderDataReturned() async {
+ clearDatasets()
+ let expectedFilterSet = Set([Filter(hash: "some", regex: "some")])
+ let expectedFilterDict = FilterDictionary(revision: 65, items: expectedFilterSet)
+ let expectedHashPrefix = Set(["sassa"])
+ embeddedDataProvider.filterSet = expectedFilterSet
+ embeddedDataProvider.hashPrefixes = expectedHashPrefix
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+
+ XCTAssertEqual(actualFilterSet, expectedFilterDict)
+ XCTAssertEqual(actualHashPrefix.set, expectedHashPrefix)
+ }
+
+ func testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async {
+ let encoder = JSONEncoder()
+ // On Disk Data Setup
+ let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")])
+ let filterSetData = try! encoder.encode(Array(onDiskFilterSet))
+ let onDiskHashPrefix = Set(["faffa"])
+ let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix))
+ fileStore.write(data: filterSetData, to: Constants.filterSetFileName)
+ fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName)
+
+ // Embedded Data Setup
+ embeddedDataProvider.embeddedRevision = 5
+ let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")])
+ let embeddedFilterDict = FilterDictionary(revision: 5, items: embeddedFilterSet)
+ let embeddedHashPrefix = Set(["sassa"])
+ embeddedDataProvider.filterSet = embeddedFilterSet
+ embeddedDataProvider.hashPrefixes = embeddedHashPrefix
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let actualFilterSetRevision = actualFilterSet.revision
+ let actualHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(actualFilterSet, embeddedFilterDict)
+ XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix)
+ XCTAssertEqual(actualFilterSetRevision, 5)
+ XCTAssertEqual(actualHashPrefixRevision, 5)
+ }
+
+ func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async {
+ // On Disk Data Setup
+ let onDiskFilterDict = FilterDictionary(revision: 6, items: [Filter(hash: "other", regex: "other")])
+ let filterSetData = try! JSONEncoder().encode(onDiskFilterDict)
+ let onDiskHashPrefix = HashPrefixSet(revision: 6, items: ["faffa"])
+ let hashPrefixData = try! JSONEncoder().encode(onDiskHashPrefix)
+ fileStore.write(data: filterSetData, to: Constants.filterSetFileName)
+ fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName)
+
+ // Embedded Data Setup
+ embeddedDataProvider.embeddedRevision = 1
+ let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")])
+ let embeddedHashPrefix = Set(["sassa"])
+ embeddedDataProvider.filterSet = embeddedFilterSet
+ embeddedDataProvider.hashPrefixes = embeddedHashPrefix
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let actualFilterSetRevision = actualFilterSet.revision
+ let actualHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(actualFilterSet, onDiskFilterDict)
+ XCTAssertEqual(actualHashPrefix, onDiskHashPrefix)
+ XCTAssertEqual(actualFilterSetRevision, 6)
+ XCTAssertEqual(actualHashPrefixRevision, 6)
+ }
+
+ func testWhenStoredDataIsMalformed_ThenEmbeddedDataIsLoaded() async {
+ // On Disk Data Setup
+ fileStore.write(data: "fake".utf8data, to: Constants.filterSetFileName)
+ fileStore.write(data: "fake".utf8data, to: Constants.hashPrefixesFileName)
+
+ // Embedded Data Setup
+ embeddedDataProvider.embeddedRevision = 1
+ let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")])
+ let embeddedFilterDict = FilterDictionary(revision: 1, items: embeddedFilterSet)
+ let embeddedHashPrefix = Set(["sassa"])
+ embeddedDataProvider.filterSet = embeddedFilterSet
+ embeddedDataProvider.hashPrefixes = embeddedHashPrefix
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let actualFilterSetRevision = actualFilterSet.revision
+ let actualHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(actualFilterSet, embeddedFilterDict)
+ XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix)
+ XCTAssertEqual(actualFilterSetRevision, 1)
+ XCTAssertEqual(actualHashPrefixRevision, 1)
+ }
+
+ func testWriteAndLoadData() async {
+ // Get and write data
+ let expectedHashPrefixes = Set(["aabb"])
+ let expectedFilterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")])
+ let expectedRevision = 65
+
+ await dataManager.store(HashPrefixSet(revision: expectedRevision, items: expectedHashPrefixes), for: .hashPrefixes(threatKind: .phishing))
+ await dataManager.store(FilterDictionary(revision: expectedRevision, items: expectedFilterSet), for: .filterSet(threatKind: .phishing))
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let actualFilterSetRevision = actualFilterSet.revision
+ let actualHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(actualFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet))
+ XCTAssertEqual(actualHashPrefix.set, expectedHashPrefixes)
+ XCTAssertEqual(actualFilterSetRevision, 65)
+ XCTAssertEqual(actualHashPrefixRevision, 65)
+
+ // Test reloading data
+ setUpDataManager()
+
+ let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let reloadedFilterSetRevision = actualFilterSet.revision
+ let reloadedHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet))
+ XCTAssertEqual(reloadedHashPrefix.set, expectedHashPrefixes)
+ XCTAssertEqual(reloadedFilterSetRevision, 65)
+ XCTAssertEqual(reloadedHashPrefixRevision, 65)
+ }
+
+ func testLazyLoadingDoesNotReturnStaleData() async {
+ clearDatasets()
+
+ // Set up initial data
+ let initialFilterSet = Set([Filter(hash: "initial", regex: "initial")])
+ let initialHashPrefixes = Set(["initialPrefix"])
+ embeddedDataProvider.filterSet = initialFilterSet
+ embeddedDataProvider.hashPrefixes = initialHashPrefixes
+
+ // Access the lazy-loaded properties to trigger loading
+ let loadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let loadedHashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+
+ // Validate loaded data matches initial data
+ XCTAssertEqual(loadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet))
+ XCTAssertEqual(loadedHashPrefixes.set, initialHashPrefixes)
+
+ // Update in-memory data
+ let updatedFilterSet = Set([Filter(hash: "updated", regex: "updated")])
+ let updatedHashPrefixes = Set(["updatedPrefix"])
+ await dataManager.store(HashPrefixSet(revision: 1, items: updatedHashPrefixes), for: .hashPrefixes(threatKind: .phishing))
+ await dataManager.store(FilterDictionary(revision: 1, items: updatedFilterSet), for: .filterSet(threatKind: .phishing))
+
+ let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let actualFilterSetRevision = actualFilterSet.revision
+ let actualHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(actualFilterSet, FilterDictionary(revision: 1, items: updatedFilterSet))
+ XCTAssertEqual(actualHashPrefix.set, updatedHashPrefixes)
+ XCTAssertEqual(actualFilterSetRevision, 1)
+ XCTAssertEqual(actualHashPrefixRevision, 1)
+
+ // Test reloading data – embedded data should be returned as its revision is greater
+ setUpDataManager()
+
+ let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let reloadedFilterSetRevision = actualFilterSet.revision
+ let reloadedHashPrefixRevision = actualFilterSet.revision
+
+ XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet))
+ XCTAssertEqual(reloadedHashPrefix.set, initialHashPrefixes)
+ XCTAssertEqual(reloadedFilterSetRevision, 1)
+ XCTAssertEqual(reloadedHashPrefixRevision, 1)
+ }
+
+}
+
+class MockMaliciousSiteProtectionFileStore: MaliciousSiteProtection.FileStoring {
+
+ private var data: [String: Data] = [:]
+
+ func write(data: Data, to filename: String) -> Bool {
+ self.data[filename] = data
+ return true
+ }
+
+ func read(from filename: String) -> Data? {
+ return data[filename]
+ }
+}
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift
new file mode 100644
index 000000000..1e3e0df40
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift
@@ -0,0 +1,62 @@
+//
+// MaliciousSiteProtectionEmbeddedDataProviderTest.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+import Foundation
+import XCTest
+
+@testable import MaliciousSiteProtection
+
+class MaliciousSiteProtectionEmbeddedDataProviderTest: XCTestCase {
+
+ struct TestEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding {
+ func revision(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> Int {
+ 0
+ }
+
+ func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL {
+ switch dataType {
+ case .filterSet(let key):
+ Bundle.module.url(forResource: "\(key.threatKind)FilterSet", withExtension: "json")!
+ case .hashPrefixSet(let key):
+ Bundle.module.url(forResource: "\(key.threatKind)HashPrefixes", withExtension: "json")!
+ }
+ }
+
+ func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String {
+ switch dataType {
+ case .filterSet(let key):
+ switch key.threatKind {
+ case .phishing:
+ "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504"
+ }
+ case .hashPrefixSet(let key):
+ switch key.threatKind {
+ case .phishing:
+ "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f"
+ }
+ }
+ }
+ }
+
+ func testDataProviderLoadsJSON() {
+ let dataProvider = TestEmbeddedDataProvider()
+ let expectedFilter = Filter(hash: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$")
+ XCTAssertTrue(dataProvider.loadDataSet(for: .filterSet(threatKind: .phishing)).contains(expectedFilter))
+ XCTAssertTrue(dataProvider.loadDataSet(for: .hashPrefixes(threatKind: .phishing)).contains("012db806"))
+ }
+
+}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift
similarity index 92%
rename from Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift
rename to Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift
index ea0576369..8df462b3e 100644
--- a/Tests/PhishingDetectionTests/PhishingDetectionURLTests.swift
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionURLTests.swift
@@ -1,5 +1,5 @@
//
-// PhishingDetectionURLTests.swift
+// MaliciousSiteProtectionURLTests.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -18,9 +18,10 @@
import Foundation
import XCTest
-@testable import PhishingDetection
-class PhishingDetectionURLTests: XCTestCase {
+@testable import MaliciousSiteProtection
+
+class MaliciousSiteProtectionURLTests: XCTestCase {
let testURLs = [
"http://www.example.com/security/badware/phishing.html#frags",
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift
new file mode 100644
index 000000000..8d46d5cf7
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift
@@ -0,0 +1,392 @@
+//
+// MaliciousSiteProtectionUpdateManagerTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Clocks
+import Common
+import Foundation
+import XCTest
+
+@testable import MaliciousSiteProtection
+
+class MaliciousSiteProtectionUpdateManagerTests: XCTestCase {
+
+ var updateManager: MaliciousSiteProtection.UpdateManager!
+ var dataManager: MockMaliciousSiteProtectionDataManager!
+ var apiClient: MaliciousSiteProtection.APIClient.Mockable!
+ var updateIntervalProvider: UpdateManager.UpdateIntervalProvider!
+ var clock: TestClock!
+ var willSleep: ((TimeInterval) -> Void)?
+ var updateTask: Task?
+
+ override func setUp() async throws {
+ apiClient = MockMaliciousSiteProtectionAPIClient()
+ dataManager = MockMaliciousSiteProtectionDataManager()
+ clock = TestClock()
+
+ let clockSleeper = Sleeper(clock: clock)
+ let reportingSleeper = Sleeper {
+ self.willSleep?($0)
+ try await clockSleeper.sleep(for: $0)
+ }
+
+ updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager, sleeper: reportingSleeper, updateIntervalProvider: { self.updateIntervalProvider($0) })
+ }
+
+ override func tearDown() async throws {
+ updateManager = nil
+ dataManager = nil
+ apiClient = nil
+ updateIntervalProvider = nil
+ updateTask?.cancel()
+ }
+
+ func testUpdateHashPrefixes() async {
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+ let dataSet = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ XCTAssertEqual(dataSet, HashPrefixSet(revision: 1, items: [
+ "aa00bb11",
+ "bb00cc11",
+ "cc00dd11",
+ "dd00ee11",
+ "a379a6f6"
+ ]))
+ }
+
+ func testUpdateFilterSet() async {
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ let dataSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ XCTAssertEqual(dataSet, FilterDictionary(revision: 1, items: [
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ Filter(hash: "testhash2", regex: ".*test.*")
+ ]))
+ }
+
+ func testRevision1AddsAndDeletesData() async {
+ let expectedFilterSet: Set = [
+ Filter(hash: "testhash2", regex: ".*test.*"),
+ Filter(hash: "testhash3", regex: ".*test.*")
+ ]
+ let expectedHashPrefixes: Set = [
+ "aa00bb11",
+ "bb00cc11",
+ "a379a6f6",
+ "93e2435e"
+ ]
+
+ // revision 0 -> 1
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ // revision 1 -> 2
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+
+ XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 2, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.")
+ XCTAssertEqual(filterSet, FilterDictionary(revision: 2, items: expectedFilterSet), "Filter set should match the expected set after update.")
+ }
+
+ func testRevision2AddsAndDeletesData() async {
+ let expectedFilterSet: Set = [
+ Filter(hash: "testhash4", regex: ".*test.*"),
+ Filter(hash: "testhash2", regex: ".*test1.*"),
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ Filter(hash: "testhash3", regex: ".*test3.*"),
+ ]
+ let expectedHashPrefixes: Set = [
+ "aa00bb11",
+ "a379a6f6",
+ "c0be0d0a6",
+ "dd00ee11",
+ "cc00dd11"
+ ]
+
+ // Save revision and update the filter set and hash prefixes
+ await dataManager.store(FilterDictionary(revision: 2, items: [
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ Filter(hash: "testhash2", regex: ".*test.*"),
+ Filter(hash: "testhash2", regex: ".*test1.*"),
+ Filter(hash: "testhash3", regex: ".*test3.*"),
+ ]), for: .filterSet(threatKind: .phishing))
+ await dataManager.store(HashPrefixSet(revision: 2, items: [
+ "aa00bb11",
+ "bb00cc11",
+ "cc00dd11",
+ "dd00ee11",
+ "a379a6f6"
+ ]), for: .hashPrefixes(threatKind: .phishing))
+
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+
+ XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.")
+ XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.")
+ }
+
+ func testRevision3AddsAndDeletesNothing() async {
+ let expectedFilterSet: Set = []
+ let expectedHashPrefixes: Set = []
+
+ // Save revision and update the filter set and hash prefixes
+ await dataManager.store(FilterDictionary(revision: 3, items: []), for: .filterSet(threatKind: .phishing))
+ await dataManager.store(HashPrefixSet(revision: 3, items: []), for: .hashPrefixes(threatKind: .phishing))
+
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+
+ XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.")
+ XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.")
+ }
+
+ func testRevision4AddsAndDeletesData() async {
+ let expectedFilterSet: Set = [
+ Filter(hash: "testhash5", regex: ".*test.*")
+ ]
+ let expectedHashPrefixes: Set = [
+ "a379a6f6",
+ ]
+
+ // Save revision and update the filter set and hash prefixes
+ await dataManager.store(FilterDictionary(revision: 4, items: []), for: .filterSet(threatKind: .phishing))
+ await dataManager.store(HashPrefixSet(revision: 4, items: []), for: .hashPrefixes(threatKind: .phishing))
+
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+
+ XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 5, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.")
+ XCTAssertEqual(filterSet, FilterDictionary(revision: 5, items: expectedFilterSet), "Filter set should match the expected set after update.")
+ }
+
+ func testRevision5replacesData() async {
+ let expectedFilterSet: Set = [
+ Filter(hash: "testhash6", regex: ".*test6.*")
+ ]
+ let expectedHashPrefixes: Set = [
+ "aa55aa55"
+ ]
+
+ // Save revision and update the filter set and hash prefixes
+ await dataManager.store(FilterDictionary(revision: 5, items: [
+ Filter(hash: "testhash2", regex: ".*test.*"),
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ Filter(hash: "testhash5", regex: ".*test.*")
+ ]), for: .filterSet(threatKind: .phishing))
+ await dataManager.store(HashPrefixSet(revision: 5, items: [
+ "a379a6f6",
+ "dd00ee11",
+ "cc00dd11",
+ "bb00cc11"
+ ]), for: .hashPrefixes(threatKind: .phishing))
+
+ await updateManager.updateData(for: .filterSet(threatKind: .phishing))
+ await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing))
+
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+
+ XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 6, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.")
+ XCTAssertEqual(filterSet, FilterDictionary(revision: 6, items: expectedFilterSet), "Filter set should match the expected set after update.")
+ }
+
+ func testWhenPeriodicUpdatesStart_dataSetsAreUpdated() async throws {
+ self.updateIntervalProvider = { _ in 1 }
+
+ let eHashPrefixesUpdated = expectation(description: "Hash prefixes updated")
+ let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in
+ eHashPrefixesUpdated.fulfill()
+ }
+ let eFilterSetUpdated = expectation(description: "Filter set updated")
+ let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in
+ eFilterSetUpdated.fulfill()
+ }
+
+ updateTask = updateManager.startPeriodicUpdates()
+ await Task.megaYield(count: 10)
+
+ // expect initial update run instantly
+ await fulfillment(of: [eHashPrefixesUpdated, eFilterSetUpdated], timeout: 1)
+
+ withExtendedLifetime((c1, c2)) {}
+ }
+
+ func testWhenPeriodicUpdatesAreEnabled_dataSetsAreUpdatedContinuously() async throws {
+ // Start periodic updates
+ self.updateIntervalProvider = { dataType in
+ switch dataType {
+ case .filterSet: return 2
+ case .hashPrefixSet: return 1
+ }
+ }
+
+ let hashPrefixUpdateExpectations = [
+ XCTestExpectation(description: "Hash prefixes rev.1 update received"),
+ XCTestExpectation(description: "Hash prefixes rev.2 update received"),
+ XCTestExpectation(description: "Hash prefixes rev.3 update received"),
+ ]
+ let filterSetUpdateExpectations = [
+ XCTestExpectation(description: "Filter set rev.1 update received"),
+ XCTestExpectation(description: "Filter set rev.2 update received"),
+ XCTestExpectation(description: "Filter set rev.3 update received"),
+ ]
+ let hashPrefixSleepExpectations = [
+ XCTestExpectation(description: "HP Will Sleep 1"),
+ XCTestExpectation(description: "HP Will Sleep 2"),
+ XCTestExpectation(description: "HP Will Sleep 3"),
+ ]
+ let filterSetSleepExpectations = [
+ XCTestExpectation(description: "FS Will Sleep 1"),
+ XCTestExpectation(description: "FS Will Sleep 2"),
+ XCTestExpectation(description: "FS Will Sleep 3"),
+ ]
+
+ let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in
+ hashPrefixUpdateExpectations[data.revision - 1].fulfill()
+ }
+ let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in
+ filterSetUpdateExpectations[data.revision - 1].fulfill()
+ }
+ var hashPrefixSleepIndex = 0
+ var filterSetSleepIndex = 0
+ self.willSleep = { interval in
+ if interval == 1 {
+ hashPrefixSleepExpectations[safe: hashPrefixSleepIndex]?.fulfill()
+ hashPrefixSleepIndex += 1
+ } else {
+ filterSetSleepExpectations[safe: filterSetSleepIndex]?.fulfill()
+ filterSetSleepIndex += 1
+ }
+ }
+
+ // expect initial hashPrefixes update run instantly
+ updateTask = updateManager.startPeriodicUpdates()
+ await fulfillment(of: [hashPrefixUpdateExpectations[0], hashPrefixSleepExpectations[0], filterSetUpdateExpectations[0], filterSetSleepExpectations[0]], timeout: 1)
+
+ // Advance the clock by 1 seconds
+ await self.clock.advance(by: .seconds(1))
+ // expect to receive v.2 update for hashPrefixes
+ await fulfillment(of: [hashPrefixUpdateExpectations[1], hashPrefixSleepExpectations[1]], timeout: 1)
+
+ // Advance the clock by 1 seconds
+ await self.clock.advance(by: .seconds(1))
+ // expect to receive v.3 update for hashPrefixes and v.2 update for filterSet
+ await fulfillment(of: [hashPrefixUpdateExpectations[2], hashPrefixSleepExpectations[2], filterSetUpdateExpectations[1], filterSetSleepExpectations[1]], timeout: 1) //
+
+ // Advance the clock by 1 seconds
+ await self.clock.advance(by: .seconds(2))
+ // expect to receive v.3 update for filterSet and no update for hashPrefixes (no v.3 updates in the mock)
+ await fulfillment(of: [filterSetUpdateExpectations[2], filterSetSleepExpectations[2]], timeout: 1) //
+
+ withExtendedLifetime((c1, c2)) {}
+ }
+
+ func testWhenPeriodicUpdatesAreDisabled_noDataSetsAreUpdated() async throws {
+ // Start periodic updates
+ self.updateIntervalProvider = { dataType in
+ switch dataType {
+ case .filterSet: return nil // Set update interval to nil for FilterSet
+ case .hashPrefixSet: return 1
+ }
+ }
+
+ let expectations = [
+ XCTestExpectation(description: "Hash prefixes rev.1 update received"),
+ XCTestExpectation(description: "Hash prefixes rev.2 update received"),
+ XCTestExpectation(description: "Hash prefixes rev.3 update received"),
+ ]
+ let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in
+ expectations[data.revision - 1].fulfill()
+ }
+ // data for FilterSet should not be updated
+ let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in
+ XCTFail("Unexpected filter set update received: \(data)")
+ }
+ // synchronize Task threads to advance the Test Clock when the updated Task is sleeping,
+ // otherwise we‘ll eventually advance the clock before the sleep and get hung.
+ var sleepIndex = 0
+ let sleepExpectations = [
+ XCTestExpectation(description: "Will Sleep 1"),
+ XCTestExpectation(description: "Will Sleep 2"),
+ XCTestExpectation(description: "Will Sleep 3"),
+ ]
+ self.willSleep = { _ in
+ sleepExpectations[sleepIndex].fulfill()
+ sleepIndex += 1
+ }
+
+ // expect initial hashPrefixes update run instantly
+ updateTask = updateManager.startPeriodicUpdates()
+ await fulfillment(of: [expectations[0], sleepExpectations[0]], timeout: 1)
+
+ // Advance the clock by 1 seconds
+ await self.clock.advance(by: .seconds(1))
+ // expect to receive v.2 update for hashPrefixes
+ await fulfillment(of: [expectations[1], sleepExpectations[1]], timeout: 1)
+
+ // Advance the clock by 1 seconds
+ await self.clock.advance(by: .seconds(1))
+ // expect to receive v.3 update for hashPrefixes
+ await fulfillment(of: [expectations[2], sleepExpectations[2]], timeout: 1)
+
+ withExtendedLifetime((c1, c2)) {}
+ }
+
+ func testWhenPeriodicUpdatesAreCancelled_noFurtherUpdatesReceived() async throws {
+ // Start periodic updates
+ self.updateIntervalProvider = { _ in 1 }
+ updateTask = updateManager.startPeriodicUpdates()
+
+ // Wait for the initial update
+ try await withTimeout(1) { [self] in
+ for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {}
+ for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {}
+ }
+
+ // Cancel the update task
+ updateTask!.cancel()
+
+ // Reset expectations for further updates
+ let c = await dataManager.$store.dropFirst().sink { data in
+ XCTFail("Unexpected data update received: \(data)")
+ }
+
+ // Advance the clock to check for further updates
+ await self.clock.advance(by: .seconds(2))
+ await clock.run()
+ await Task.megaYield(count: 10)
+
+ // Verify that the data sets have not been updated further
+ let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing))
+ let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing))
+ XCTAssertEqual(hashPrefixes.revision, 1) // Expecting revision to remain 1
+ XCTAssertEqual(filterSet.revision, 1) // Expecting revision to remain 1
+
+ withExtendedLifetime(c) {}
+ }
+
+}
diff --git a/Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift
similarity index 80%
rename from Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift
rename to Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift
index 7c736c7e3..1edbb98a2 100644
--- a/Tests/PhishingDetectionTests/Mocks/EventMappingMock.swift
+++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockEventMapping.swift
@@ -1,5 +1,5 @@
//
-// EventMappingMock.swift
+// MockEventMapping.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -15,13 +15,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
-import Foundation
+
import Common
-import PhishingDetection
+import Foundation
+import MaliciousSiteProtection
import PixelKit
-public class MockEventMapping: EventMapping {
- static var events: [PhishingDetectionEvents] = []
+public class MockEventMapping: EventMapping {
+ static var events: [MaliciousSiteProtection.Event] = []
static var clientSideHitParam: String?
static var errorParam: Error?
@@ -39,7 +40,7 @@ public class MockEventMapping: EventMapping {
}
}
- override init(mapping: @escaping EventMapping.Mapping) {
+ override init(mapping: @escaping EventMapping.Mapping) {
fatalError("Use init()")
}
}
diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift
new file mode 100644
index 000000000..4f2062edd
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift
@@ -0,0 +1,103 @@
+//
+// MockMaliciousSiteProtectionAPIClient.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+@testable import MaliciousSiteProtection
+
+class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mockable {
+ var updateHashPrefixesCalled: ((Int) -> Void)?
+ var updateFilterSetsCalled: ((Int) -> Void)?
+
+ var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [
+ 0: .init(insert: [
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ Filter(hash: "testhash2", regex: ".*test.*")
+ ], delete: [], revision: 1, replace: false),
+ 1: .init(insert: [
+ Filter(hash: "testhash3", regex: ".*test.*")
+ ], delete: [
+ Filter(hash: "testhash1", regex: ".*example.*"),
+ ], revision: 2, replace: false),
+ 2: .init(insert: [
+ Filter(hash: "testhash4", regex: ".*test.*")
+ ], delete: [
+ Filter(hash: "testhash2", regex: ".*test.*"),
+ ], revision: 3, replace: false),
+ 4: .init(insert: [
+ Filter(hash: "testhash5", regex: ".*test.*")
+ ], delete: [
+ Filter(hash: "testhash3", regex: ".*test.*"),
+ ], revision: 5, replace: false),
+ 5: .init(insert: [
+ Filter(hash: "testhash6", regex: ".*test6.*")
+ ], delete: [
+ Filter(hash: "testhash3", regex: ".*test.*"),
+ ], revision: 6, replace: true),
+ ]
+
+ private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [
+ 0: .init(insert: [
+ "aa00bb11",
+ "bb00cc11",
+ "cc00dd11",
+ "dd00ee11",
+ "a379a6f6"
+ ], delete: [], revision: 1, replace: false),
+ 1: .init(insert: ["93e2435e"], delete: [
+ "cc00dd11",
+ "dd00ee11",
+ ], revision: 2, replace: false),
+ 2: .init(insert: ["c0be0d0a6"], delete: [
+ "bb00cc11",
+ ], revision: 3, replace: false),
+ 4: .init(insert: ["a379a6f6"], delete: [
+ "aa00bb11",
+ ], revision: 5, replace: false),
+ 5: .init(insert: ["aa55aa55"], delete: [
+ "ffgghhzz",
+ ], revision: 6, replace: true),
+ ]
+
+ func load(_ requestConfig: Request) async throws -> Request.Response where Request: APIClient.Request {
+ switch requestConfig.requestType {
+ case .hashPrefixSet(let configuration):
+ return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response
+ case .filterSet(let configuration):
+ return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response
+ case .matches(let configuration):
+ return _matches(forHashPrefix: configuration.hashPrefix) as! Request.Response
+ }
+ }
+ func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet {
+ updateFilterSetsCalled?(revision)
+ return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false)
+ }
+
+ func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet {
+ updateHashPrefixesCalled?(revision)
+ return hashPrefixRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false)
+ }
+
+ func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches {
+ .init(matches: [
+ Match(hostname: "example.com", url: "https://example.com/mal", regex: ".*", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil),
+ Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil)
+ ])
+ }
+
+}
diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift
new file mode 100644
index 000000000..1a67ad329
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift
@@ -0,0 +1,40 @@
+//
+// MockMaliciousSiteProtectionDataManager.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Combine
+import Foundation
+@testable import MaliciousSiteProtection
+
+actor MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging {
+
+ @Published var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]()
+ func publisher(for key: DataKey) -> AnyPublisher where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey {
+ $store.map { $0[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) }
+ .removeDuplicates()
+ .eraseToAnyPublisher()
+ }
+
+ public func dataSet(for key: DataKey) -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey {
+ return store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: [])
+ }
+
+ func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey {
+ store[key.dataType] = dataSet
+ }
+
+}
diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift
new file mode 100644
index 000000000..10a3e2643
--- /dev/null
+++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift
@@ -0,0 +1,81 @@
+//
+// MockMaliciousSiteProtectionEmbeddedDataProvider.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Foundation
+@testable import MaliciousSiteProtection
+
+final class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding {
+ var embeddedRevision: Int = 65
+ var loadHashPrefixesCalled: Bool = false
+ var loadFilterSetCalled: Bool = true
+ var hashPrefixes: Set = [] {
+ didSet {
+ hashPrefixesData = try! JSONEncoder().encode(hashPrefixes)
+ }
+ }
+ var hashPrefixesData: Data!
+
+ var filterSet: Set = [] {
+ didSet {
+ filterSetData = try! JSONEncoder().encode(filterSet)
+ }
+ }
+ var filterSetData: Data!
+
+ init() {
+ hashPrefixes = Set(["aabb"])
+ filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")])
+ }
+
+ func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int {
+ embeddedRevision
+ }
+
+ func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL {
+ switch dataType {
+ case .filterSet:
+ self.loadFilterSetCalled = true
+ return URL(string: "filterSet")!
+ case .hashPrefixSet:
+ self.loadHashPrefixesCalled = true
+ return URL(string: "hashPrefixSet")!
+ }
+ }
+
+ func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String {
+ let url = url(for: dataType)
+ let data = try! data(withContentsOf: url)
+ let sha = data.sha256
+ return sha
+ }
+
+ func data(withContentsOf url: URL) throws -> Data {
+ let data: Data
+ switch url.absoluteString {
+ case "filterSet":
+ self.loadFilterSetCalled = true
+ return filterSetData
+ case "hashPrefixSet":
+ self.loadHashPrefixesCalled = true
+ return hashPrefixesData
+ default:
+ fatalError("Unexpected url \(url.absoluteString)")
+ }
+ }
+
+}
diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift
similarity index 59%
rename from Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift
rename to Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift
index d5ca12559..b49eac588 100644
--- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionUpdateManagerMock.swift
+++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift
@@ -1,5 +1,5 @@
//
-// PhishingDetectionUpdateManagerMock.swift
+// MockPhishingDetectionUpdateManager.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
@@ -17,27 +17,40 @@
//
import Foundation
-import PhishingDetection
+@testable import MaliciousSiteProtection
+
+class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging {
-public class MockPhishingDetectionUpdateManager: PhishingDetectionUpdateManaging {
var didUpdateFilterSet = false
var didUpdateHashPrefixes = false
+ var startPeriodicUpdatesCalled = false
var completionHandler: (() -> Void)?
- public func updateFilterSet() async {
+ func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async {
+ switch key.dataType {
+ case .filterSet: await updateFilterSet()
+ case .hashPrefixSet: await updateHashPrefixes()
+ }
+ }
+
+ func updateFilterSet() async {
didUpdateFilterSet = true
checkCompletion()
}
- public func updateHashPrefixes() async {
+ func updateHashPrefixes() async {
didUpdateHashPrefixes = true
checkCompletion()
}
- private func checkCompletion() {
+ func checkCompletion() {
if didUpdateFilterSet && didUpdateHashPrefixes {
completionHandler?()
}
}
+ public func startPeriodicUpdates() -> Task {
+ startPeriodicUpdatesCalled = true
+ return Task {}
+ }
}
diff --git a/Tests/PhishingDetectionTests/Resources/filterSet.json b/Tests/MaliciousSiteProtectionTests/Resources/phishingFilterSet.json
similarity index 100%
rename from Tests/PhishingDetectionTests/Resources/filterSet.json
rename to Tests/MaliciousSiteProtectionTests/Resources/phishingFilterSet.json
diff --git a/Tests/PhishingDetectionTests/Resources/hashPrefixes.json b/Tests/MaliciousSiteProtectionTests/Resources/phishingHashPrefixes.json
similarity index 100%
rename from Tests/PhishingDetectionTests/Resources/hashPrefixes.json
rename to Tests/MaliciousSiteProtectionTests/Resources/phishingHashPrefixes.json
diff --git a/Tests/NavigationTests/Helpers/NavigationResponderMock.swift b/Tests/NavigationTests/Helpers/NavigationResponderMock.swift
index d39a6ee44..fda1b2805 100644
--- a/Tests/NavigationTests/Helpers/NavigationResponderMock.swift
+++ b/Tests/NavigationTests/Helpers/NavigationResponderMock.swift
@@ -374,7 +374,6 @@ class NavigationResponderMock: NavigationResponder {
var onDidTerminate: (@MainActor (WKProcessTerminationReason?) -> Void)?
func webContentProcessDidTerminate(with reason: WKProcessTerminationReason?) {
- let event = append(.didTerminate(reason))
onDidTerminate?(reason)
}
diff --git a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift
index 59eeadebb..4ec1b8b59 100644
--- a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift
+++ b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift
@@ -41,21 +41,18 @@ final class APIRequestV2Tests: XCTestCase {
cachePolicy: cachePolicy,
responseConstraints: constraints)
- guard let urlRequest = apiRequest?.urlRequest else {
- XCTFail("Nil URLRequest")
- return
- }
+ let urlRequest = apiRequest.urlRequest
XCTAssertEqual(urlRequest.url?.host(), url.host())
XCTAssertEqual(urlRequest.httpMethod, method.rawValue)
let urlComponents = URLComponents(string: urlRequest.url!.absoluteString)!
- XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems()))
+ XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value")))
XCTAssertEqual(urlRequest.allHTTPHeaderFields, headers.httpHeaders)
XCTAssertEqual(urlRequest.httpBody, body)
- XCTAssertEqual(apiRequest?.timeoutInterval, timeoutInterval)
+ XCTAssertEqual(apiRequest.timeoutInterval, timeoutInterval)
XCTAssertEqual(urlRequest.cachePolicy, cachePolicy)
- XCTAssertEqual(apiRequest?.responseConstraints, constraints)
+ XCTAssertEqual(apiRequest.responseConstraints, constraints)
}
func testURLRequestGeneration() {
@@ -75,16 +72,16 @@ final class APIRequestV2Tests: XCTestCase {
timeoutInterval: timeoutInterval,
cachePolicy: cachePolicy)
- let urlComponents = URLComponents(string: apiRequest!.urlRequest.url!.absoluteString)!
- XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems()))
+ let urlComponents = URLComponents(string: apiRequest.urlRequest.url!.absoluteString)!
+ XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value")))
XCTAssertNotNil(apiRequest)
- XCTAssertEqual(apiRequest?.urlRequest.url?.absoluteString, "https://www.example.com?key=value")
- XCTAssertEqual(apiRequest?.urlRequest.httpMethod, method.rawValue)
- XCTAssertEqual(apiRequest?.urlRequest.allHTTPHeaderFields, headers.httpHeaders)
- XCTAssertEqual(apiRequest?.urlRequest.httpBody, body)
- XCTAssertEqual(apiRequest?.urlRequest.timeoutInterval, timeoutInterval)
- XCTAssertEqual(apiRequest?.urlRequest.cachePolicy, cachePolicy)
+ XCTAssertEqual(apiRequest.urlRequest.url?.absoluteString, "https://www.example.com?key=value")
+ XCTAssertEqual(apiRequest.urlRequest.httpMethod, method.rawValue)
+ XCTAssertEqual(apiRequest.urlRequest.allHTTPHeaderFields, headers.httpHeaders)
+ XCTAssertEqual(apiRequest.urlRequest.httpBody, body)
+ XCTAssertEqual(apiRequest.urlRequest.timeoutInterval, timeoutInterval)
+ XCTAssertEqual(apiRequest.urlRequest.cachePolicy, cachePolicy)
}
func testDefaultValues() {
@@ -92,16 +89,13 @@ final class APIRequestV2Tests: XCTestCase {
let apiRequest = APIRequestV2(url: url)
let headers = APIRequestV2.HeadersV2()
- guard let urlRequest = apiRequest?.urlRequest else {
- XCTFail("Nil URLRequest")
- return
- }
+ let urlRequest = apiRequest.urlRequest
XCTAssertEqual(urlRequest.httpMethod, HTTPRequestMethod.get.rawValue)
XCTAssertEqual(urlRequest.timeoutInterval, 60.0)
XCTAssertEqual(headers.httpHeaders, urlRequest.allHTTPHeaderFields)
XCTAssertNil(urlRequest.httpBody)
XCTAssertEqual(urlRequest.cachePolicy.rawValue, 0)
- XCTAssertNil(apiRequest?.responseConstraints)
+ XCTAssertNil(apiRequest.responseConstraints)
}
func testAllowedQueryReservedCharacters() {
@@ -112,9 +106,10 @@ final class APIRequestV2Tests: XCTestCase {
queryItems: queryItems,
allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))
- let urlString = apiRequest!.urlRequest.url!.absoluteString
- XCTAssertTrue(urlString == "https://www.example.com?k%2523e,y=val%2523ue")
+ let urlString = apiRequest.urlRequest.url!.absoluteString
+ XCTAssertEqual(urlString, "https://www.example.com?k%23e,y=val%23ue")
+
let urlComponents = URLComponents(string: urlString)!
- XCTAssertTrue(urlComponents.queryItems?.count == 1)
+ XCTAssertEqual(urlComponents.queryItems?.count, 1)
}
}
diff --git a/Tests/NetworkingTests/v2/APIServiceTests.swift b/Tests/NetworkingTests/v2/APIServiceTests.swift
index 9cae44323..730d6afbb 100644
--- a/Tests/NetworkingTests/v2/APIServiceTests.swift
+++ b/Tests/NetworkingTests/v2/APIServiceTests.swift
@@ -41,7 +41,7 @@ final class APIServiceTests: XCTestCase {
cachePolicy: .reloadIgnoringLocalAndRemoteCacheData,
responseConstraints: [APIResponseConstraints.allowHTTPNotModified,
APIResponseConstraints.requireETagHeader],
- allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))!
+ allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))
let apiService = DefaultAPIService()
let response = try await apiService.fetch(request: request)
let responseHTML: String = try response.decodeBody()
@@ -50,7 +50,7 @@ final class APIServiceTests: XCTestCase {
func disabled_testRealCallJSON() async throws {
// func testRealCallJSON() async throws {
- let request = APIRequestV2(url: HTTPURLResponse.testUrl)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl)
let apiService = DefaultAPIService()
let result = try await apiService.fetch(request: request)
@@ -63,7 +63,7 @@ final class APIServiceTests: XCTestCase {
func disabled_testRealCallString() async throws {
// func testRealCallString() async throws {
- let request = APIRequestV2(url: HTTPURLResponse.testUrl)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl)
let apiService = DefaultAPIService()
let result = try await apiService.fetch(request: request)
@@ -75,17 +75,16 @@ final class APIServiceTests: XCTestCase {
"qName2": "qValue2"]
MockURLProtocol.requestHandler = { request in
let urlComponents = URLComponents(string: request.url!.absoluteString)!
- XCTAssertTrue(urlComponents.queryItems!.contains(qItems.toURLQueryItems()))
+ XCTAssertTrue(urlComponents.queryItems!.contains(qItems.map { URLQueryItem(name: $0.key, value: $0.value) }))
return (HTTPURLResponse.ok, nil)
}
- let request = APIRequestV2(url: HTTPURLResponse.testUrl,
- queryItems: qItems)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems)
let apiService = DefaultAPIService(urlSession: mockURLSession)
_ = try await apiService.fetch(request: request)
}
func testURLRequestError() async throws {
- let request = APIRequestV2(url: HTTPURLResponse.testUrl)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl)
enum TestError: Error {
case anError
@@ -111,7 +110,7 @@ final class APIServiceTests: XCTestCase {
func testResponseRequirementAllowHTTPNotModifiedSuccess() async throws {
let requirements = [APIResponseConstraints.allowHTTPNotModified ]
- let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)
MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) }
@@ -122,7 +121,7 @@ final class APIServiceTests: XCTestCase {
}
func testResponseRequirementAllowHTTPNotModifiedFailure() async throws {
- let request = APIRequestV2(url: HTTPURLResponse.testUrl)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl)
MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) }
@@ -147,7 +146,7 @@ final class APIServiceTests: XCTestCase {
let requirements: [APIResponseConstraints] = [
APIResponseConstraints.requireETagHeader
]
- let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)
MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } // HTTPURLResponse.ok contains etag
let apiService = DefaultAPIService(urlSession: mockURLSession)
@@ -158,7 +157,7 @@ final class APIServiceTests: XCTestCase {
func testResponseRequirementRequireETagHeaderFailure() async throws {
let requirements = [ APIResponseConstraints.requireETagHeader ]
- let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)
MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okNoEtag, nil) }
@@ -181,7 +180,7 @@ final class APIServiceTests: XCTestCase {
func testResponseRequirementRequireUserAgentSuccess() async throws {
let requirements = [ APIResponseConstraints.requireUserAgent ]
- let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)
MockURLProtocol.requestHandler = { _ in
( HTTPURLResponse.okUserAgent, nil)
@@ -194,7 +193,7 @@ final class APIServiceTests: XCTestCase {
func testResponseRequirementRequireUserAgentFailure() async throws {
let requirements = [ APIResponseConstraints.requireUserAgent ]
- let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)!
+ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)
MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) }
diff --git a/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift b/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift
index ed7927f4f..942543505 100644
--- a/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift
+++ b/Tests/OnboardingTests/OnboardingSuggestionsViewModelsTests.swift
@@ -147,11 +147,11 @@ class CapturingOnboardingNavigationDelegate: OnboardingNavigationDelegate {
var suggestedSearchQuery: String?
var urlToNavigateTo: URL?
- func searchFor(_ query: String) {
+ func searchFromOnboarding(for query: String) {
suggestedSearchQuery = query
}
- func navigateTo(url: URL) {
+ func navigateFromOnboarding(to url: URL) {
urlToNavigateTo = url
}
}
diff --git a/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift b/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift
deleted file mode 100644
index 8d907efbd..000000000
--- a/Tests/PhishingDetectionTests/BackgroundActivitySchedulerTests.swift
+++ /dev/null
@@ -1,57 +0,0 @@
-//
-// BackgroundActivitySchedulerTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-import Foundation
-import XCTest
-@testable import PhishingDetection
-
-class BackgroundActivitySchedulerTests: XCTestCase {
- var scheduler: BackgroundActivityScheduler!
- var activityWasRun = false
-
- override func tearDown() {
- scheduler = nil
- super.tearDown()
- }
-
- func testStart() async throws {
- let expectation = self.expectation(description: "Activity should run")
- scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") {
- if !self.activityWasRun {
- self.activityWasRun = true
- expectation.fulfill()
- }
- }
- await scheduler.start()
- await fulfillment(of: [expectation], timeout: 2)
- XCTAssertTrue(activityWasRun)
- }
-
- func testRepeats() async throws {
- let expectation = self.expectation(description: "Activity should repeat")
- var runCount = 0
- scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") {
- runCount += 1
- if runCount == 2 {
- expectation.fulfill()
- }
- }
- await scheduler.start()
- await fulfillment(of: [expectation], timeout: 3)
- XCTAssertEqual(runCount, 2)
- }
-}
diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift
deleted file mode 100644
index 9f39598b2..000000000
--- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionClientMock.swift
+++ /dev/null
@@ -1,84 +0,0 @@
-//
-// PhishingDetectionClientMock.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import PhishingDetection
-
-public class MockPhishingDetectionClient: PhishingDetectionClientProtocol {
- public var updateHashPrefixesWasCalled: Bool = false
- public var updateFilterSetsWasCalled: Bool = false
-
- private var filterRevisions: [Int: FilterSetResponse] = [
- 0: FilterSetResponse(insert: [
- Filter(hashValue: "testhash1", regex: ".*example.*"),
- Filter(hashValue: "testhash2", regex: ".*test.*")
- ], delete: [], revision: 0, replace: true),
- 1: FilterSetResponse(insert: [
- Filter(hashValue: "testhash3", regex: ".*test.*")
- ], delete: [
- Filter(hashValue: "testhash1", regex: ".*example.*"),
- ], revision: 1, replace: false),
- 2: FilterSetResponse(insert: [
- Filter(hashValue: "testhash4", regex: ".*test.*")
- ], delete: [
- Filter(hashValue: "testhash2", regex: ".*test.*"),
- ], revision: 2, replace: false),
- 4: FilterSetResponse(insert: [
- Filter(hashValue: "testhash5", regex: ".*test.*")
- ], delete: [
- Filter(hashValue: "testhash3", regex: ".*test.*"),
- ], revision: 4, replace: false)
- ]
-
- private var hashPrefixRevisions: [Int: HashPrefixResponse] = [
- 0: HashPrefixResponse(insert: [
- "aa00bb11",
- "bb00cc11",
- "cc00dd11",
- "dd00ee11",
- "a379a6f6"
- ], delete: [], revision: 0, replace: true),
- 1: HashPrefixResponse(insert: ["93e2435e"], delete: [
- "cc00dd11",
- "dd00ee11",
- ], revision: 1, replace: false),
- 2: HashPrefixResponse(insert: ["c0be0d0a6"], delete: [
- "bb00cc11",
- ], revision: 2, replace: false),
- 4: HashPrefixResponse(insert: ["a379a6f6"], delete: [
- "aa00bb11",
- ], revision: 4, replace: false)
- ]
-
- public func getFilterSet(revision: Int) async -> FilterSetResponse {
- updateFilterSetsWasCalled = true
- return filterRevisions[revision] ?? FilterSetResponse(insert: [], delete: [], revision: revision, replace: false)
- }
-
- public func getHashPrefixes(revision: Int) async -> HashPrefixResponse {
- updateHashPrefixesWasCalled = true
- return hashPrefixRevisions[revision] ?? HashPrefixResponse(insert: [], delete: [], revision: revision, replace: false)
- }
-
- public func getMatches(hashPrefix: String) async -> [Match] {
- return [
- Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947"),
- Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11")
- ]
- }
-}
diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift
deleted file mode 100644
index 79d4d5d6b..000000000
--- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataProviderMock.swift
+++ /dev/null
@@ -1,47 +0,0 @@
-//
-// PhishingDetectionDataProviderMock.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import PhishingDetection
-
-public class MockPhishingDetectionDataProvider: PhishingDetectionDataProviding {
- public var embeddedRevision: Int = 65
- var loadHashPrefixesCalled: Bool = false
- var loadFilterSetCalled: Bool = true
- var hashPrefixes: Set = ["aabb"]
- var filterSet: Set = [Filter(hashValue: "dummyhash", regex: "dummyregex")]
-
- public func shouldReturnFilterSet(set: Set) {
- self.filterSet = set
- }
-
- public func shouldReturnHashPrefixes(set: Set) {
- self.hashPrefixes = set
- }
-
- public func loadEmbeddedFilterSet() -> Set {
- self.loadHashPrefixesCalled = true
- return self.filterSet
- }
-
- public func loadEmbeddedHashPrefixes() -> Set {
- self.loadFilterSetCalled = true
- return self.hashPrefixes
- }
-
-}
diff --git a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift b/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift
deleted file mode 100644
index 54521419c..000000000
--- a/Tests/PhishingDetectionTests/Mocks/PhishingDetectionDataStoreMock.swift
+++ /dev/null
@@ -1,44 +0,0 @@
-//
-// PhishingDetectionDataStoreMock.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import PhishingDetection
-
-public class MockPhishingDetectionDataStore: PhishingDetectionDataSaving {
- public var filterSet: Set
- public var hashPrefixes: Set
- public var currentRevision: Int
-
- public init() {
- filterSet = Set()
- hashPrefixes = Set()
- currentRevision = 0
- }
-
- public func saveFilterSet(set: Set) {
- filterSet = set
- }
-
- public func saveHashPrefixes(set: Set) {
- hashPrefixes = set
- }
-
- public func saveRevision(_ revision: Int) {
- currentRevision = revision
- }
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift
deleted file mode 100644
index 6826c86d6..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectionClientTests.swift
+++ /dev/null
@@ -1,125 +0,0 @@
-//
-// PhishingDetectionClientTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-import Foundation
-import XCTest
-@testable import PhishingDetection
-
-final class PhishingDetectionAPIClientTests: XCTestCase {
-
- var mockSession: MockURLSession!
- var client: PhishingDetectionAPIClient!
-
- override func setUp() {
- super.setUp()
- mockSession = MockURLSession()
- client = PhishingDetectionAPIClient(environment: .staging, session: mockSession)
- }
-
- override func tearDown() {
- mockSession = nil
- client = nil
- super.tearDown()
- }
-
- func testGetFilterSetSuccess() async {
- // Given
- let insertFilter = Filter(hashValue: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".")
- let deleteFilter = Filter(hashValue: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?")
- let expectedResponse = FilterSetResponse(insert: [insertFilter], delete: [deleteFilter], revision: 1, replace: false)
- mockSession.data = try? JSONEncoder().encode(expectedResponse)
- mockSession.response = HTTPURLResponse(url: client.filterSetURL, statusCode: 200, httpVersion: nil, headerFields: nil)
-
- // When
- let response = await client.getFilterSet(revision: 1)
-
- // Then
- XCTAssertEqual(response, expectedResponse)
- }
-
- func testGetHashPrefixesSuccess() async {
- // Given
- let expectedResponse = HashPrefixResponse(insert: ["abc"], delete: ["def"], revision: 1, replace: false)
- mockSession.data = try? JSONEncoder().encode(expectedResponse)
- mockSession.response = HTTPURLResponse(url: client.hashPrefixURL, statusCode: 200, httpVersion: nil, headerFields: nil)
-
- // When
- let response = await client.getHashPrefixes(revision: 1)
-
- // Then
- XCTAssertEqual(response, expectedResponse)
- }
-
- func testGetMatchesSuccess() async {
- // Given
- let expectedResponse = MatchResponse(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947")])
- mockSession.data = try? JSONEncoder().encode(expectedResponse)
- mockSession.response = HTTPURLResponse(url: client.matchesURL, statusCode: 200, httpVersion: nil, headerFields: nil)
-
- // When
- let response = await client.getMatches(hashPrefix: "abc")
-
- // Then
- XCTAssertEqual(response, expectedResponse.matches)
- }
-
- func testGetFilterSetInvalidURL() async {
- // Given
- let invalidRevision = -1
-
- // When
- let response = await client.getFilterSet(revision: invalidRevision)
-
- // Then
- XCTAssertEqual(response, FilterSetResponse(insert: [], delete: [], revision: invalidRevision, replace: false))
- }
-
- func testGetHashPrefixesInvalidURL() async {
- // Given
- let invalidRevision = -1
-
- // When
- let response = await client.getHashPrefixes(revision: invalidRevision)
-
- // Then
- XCTAssertEqual(response, HashPrefixResponse(insert: [], delete: [], revision: invalidRevision, replace: false))
- }
-
- func testGetMatchesInvalidURL() async {
- // Given
- let invalidHashPrefix = ""
-
- // When
- let response = await client.getMatches(hashPrefix: invalidHashPrefix)
-
- // Then
- XCTAssertTrue(response.isEmpty)
- }
-}
-
-class MockURLSession: URLSessionProtocol {
- var data: Data?
- var response: URLResponse?
- var error: Error?
-
- func data(for request: URLRequest) async throws -> (Data, URLResponse) {
- if let error = error {
- throw error
- }
- return (data ?? Data(), response ?? URLResponse())
- }
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift
deleted file mode 100644
index 583f94789..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectionDataActivitiesTests.swift
+++ /dev/null
@@ -1,48 +0,0 @@
-//
-// PhishingDetectionDataActivitiesTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-import XCTest
-@testable import PhishingDetection
-
-class PhishingDetectionDataActivitiesTests: XCTestCase {
- var mockUpdateManager: MockPhishingDetectionUpdateManager!
- var activities: PhishingDetectionDataActivities!
-
- override func setUp() {
- super.setUp()
- mockUpdateManager = MockPhishingDetectionUpdateManager()
- activities = PhishingDetectionDataActivities(hashPrefixInterval: 1, filterSetInterval: 1, phishingDetectionDataProvider: MockPhishingDetectionDataProvider(), updateManager: mockUpdateManager)
- }
-
- func testUpdateHashPrefixesAndFilterSetRuns() async {
- let expectation = XCTestExpectation(description: "updateHashPrefixes and updateFilterSet completes")
-
- mockUpdateManager.completionHandler = {
- expectation.fulfill()
- }
-
- activities.start()
-
- await fulfillment(of: [expectation], timeout: 10.0)
-
- XCTAssertTrue(mockUpdateManager.didUpdateHashPrefixes)
- XCTAssertTrue(mockUpdateManager.didUpdateFilterSet)
-
- }
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift
deleted file mode 100644
index 547f2dce8..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectionDataProviderTest.swift
+++ /dev/null
@@ -1,52 +0,0 @@
-//
-// PhishingDetectionDataProviderTest.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-import Foundation
-import XCTest
-@testable import PhishingDetection
-
-class PhishingDetectionDataProviderTest: XCTestCase {
- var filterSetURL: URL!
- var hashPrefixURL: URL!
- var dataProvider: PhishingDetectionDataProvider!
-
- override func setUp() {
- super.setUp()
- filterSetURL = Bundle.module.url(forResource: "filterSet", withExtension: "json")!
- hashPrefixURL = Bundle.module.url(forResource: "hashPrefixes", withExtension: "json")!
- }
-
- override func tearDown() {
- filterSetURL = nil
- hashPrefixURL = nil
- dataProvider = nil
- super.tearDown()
- }
-
- func testDataProviderLoadsJSON() {
- dataProvider = PhishingDetectionDataProvider(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f")
- let expectedFilter = Filter(hashValue: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$")
- XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().contains(expectedFilter))
- XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().contains("012db806"))
- }
-
- func testReturnsNoneWhenSHAMismatch() {
- dataProvider = PhishingDetectionDataProvider(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "xx0", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "00x")
- XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().isEmpty)
- XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().isEmpty)
- }
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift
deleted file mode 100644
index 79e9fb500..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectionDataStoreTests.swift
+++ /dev/null
@@ -1,197 +0,0 @@
-//
-// PhishingDetectionDataStoreTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-
-import XCTest
-@testable import PhishingDetection
-
-class PhishingDetectionDataStoreTests: XCTestCase {
- var mockDataProvider: MockPhishingDetectionDataProvider!
- let datasetFiles: [String] = ["hashPrefixes.json", "filterSet.json", "revision.txt"]
- var dataStore: PhishingDetectionDataStore!
- var fileStorageManager: FileStorageManager!
-
- override func setUp() {
- super.setUp()
- mockDataProvider = MockPhishingDetectionDataProvider()
- fileStorageManager = MockPhishingFileStorageManager()
- dataStore = PhishingDetectionDataStore(dataProvider: mockDataProvider, fileStorageManager: fileStorageManager)
- }
-
- override func tearDown() {
- mockDataProvider = nil
- dataStore = nil
- super.tearDown()
- }
-
- func clearDatasets() {
- for fileName in datasetFiles {
- let emptyData = Data()
- fileStorageManager.write(data: emptyData, to: fileName)
- }
- }
-
- func testWhenNoDataSavedThenProviderDataReturned() async {
- clearDatasets()
- let expectedFilerSet = Set([Filter(hashValue: "some", regex: "some")])
- let expectedHashPrefix = Set(["sassa"])
- mockDataProvider.shouldReturnFilterSet(set: expectedFilerSet)
- mockDataProvider.shouldReturnHashPrefixes(set: expectedHashPrefix)
-
- let actualFilterSet = dataStore.filterSet
- let actualHashPrefix = dataStore.hashPrefixes
-
- XCTAssertEqual(actualFilterSet, expectedFilerSet)
- XCTAssertEqual(actualHashPrefix, expectedHashPrefix)
- }
-
- func testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async {
- let encoder = JSONEncoder()
- // On Disk Data Setup
- fileStorageManager.write(data: "1".utf8data, to: "revision.txt")
- let onDiskFilterSet = Set([Filter(hashValue: "other", regex: "other")])
- let filterSetData = try! encoder.encode(Array(onDiskFilterSet))
- let onDiskHashPrefix = Set(["faffa"])
- let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix))
- fileStorageManager.write(data: filterSetData, to: "filterSet.json")
- fileStorageManager.write(data: hashPrefixData, to: "hashPrefixes.json")
-
- // Embedded Data Setup
- mockDataProvider.embeddedRevision = 5
- let embeddedFilterSet = Set([Filter(hashValue: "some", regex: "some")])
- let embeddedHashPrefix = Set(["sassa"])
- mockDataProvider.shouldReturnFilterSet(set: embeddedFilterSet)
- mockDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix)
-
- let actualRevision = dataStore.currentRevision
- let actualFilterSet = dataStore.filterSet
- let actualHashPrefix = dataStore.hashPrefixes
-
- XCTAssertEqual(actualFilterSet, embeddedFilterSet)
- XCTAssertEqual(actualHashPrefix, embeddedHashPrefix)
- XCTAssertEqual(actualRevision, 5)
- }
-
- func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async {
- let encoder = JSONEncoder()
- // On Disk Data Setup
- fileStorageManager.write(data: "6".utf8data, to: "revision.txt")
- let onDiskFilterSet = Set([Filter(hashValue: "other", regex: "other")])
- let filterSetData = try! encoder.encode(Array(onDiskFilterSet))
- let onDiskHashPrefix = Set(["faffa"])
- let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix))
- fileStorageManager.write(data: filterSetData, to: "filterSet.json")
- fileStorageManager.write(data: hashPrefixData, to: "hashPrefixes.json")
-
- // Embedded Data Setup
- mockDataProvider.embeddedRevision = 1
- let embeddedFilterSet = Set([Filter(hashValue: "some", regex: "some")])
- let embeddedHashPrefix = Set(["sassa"])
- mockDataProvider.shouldReturnFilterSet(set: embeddedFilterSet)
- mockDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix)
-
- let actualRevision = dataStore.currentRevision
- let actualFilterSet = dataStore.filterSet
- let actualHashPrefix = dataStore.hashPrefixes
-
- XCTAssertEqual(actualFilterSet, onDiskFilterSet)
- XCTAssertEqual(actualHashPrefix, onDiskHashPrefix)
- XCTAssertEqual(actualRevision, 6)
- }
-
- func testWriteAndLoadData() async {
- // Get and write data
- let expectedHashPrefixes = Set(["aabb"])
- let expectedFilterSet = Set([Filter(hashValue: "dummyhash", regex: "dummyregex")])
- let expectedRevision = 65
-
- dataStore.saveHashPrefixes(set: expectedHashPrefixes)
- dataStore.saveFilterSet(set: expectedFilterSet)
- dataStore.saveRevision(expectedRevision)
-
- XCTAssertEqual(dataStore.filterSet, expectedFilterSet)
- XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes)
- XCTAssertEqual(dataStore.currentRevision, expectedRevision)
-
- // Test decode JSON data to expected types
- let storedHashPrefixesData = fileStorageManager.read(from: "hashPrefixes.json")
- let storedFilterSetData = fileStorageManager.read(from: "filterSet.json")
- let storedRevisionData = fileStorageManager.read(from: "revision.txt")
-
- let decoder = JSONDecoder()
- if let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!),
- let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!),
- let storedRevisionString = String(data: storedRevisionData!, encoding: .utf8),
- let storedRevision = Int(storedRevisionString.trimmingCharacters(in: .whitespacesAndNewlines)) {
-
- XCTAssertEqual(storedFilterSet, expectedFilterSet)
- XCTAssertEqual(storedHashPrefixes, expectedHashPrefixes)
- XCTAssertEqual(storedRevision, expectedRevision)
- } else {
- XCTFail("Failed to decode stored PhishingDetection data")
- }
- }
-
- func testLazyLoadingDoesNotReturnStaleData() async {
- clearDatasets()
-
- // Set up initial data
- let initialFilterSet = Set([Filter(hashValue: "initial", regex: "initial")])
- let initialHashPrefixes = Set(["initialPrefix"])
- mockDataProvider.shouldReturnFilterSet(set: initialFilterSet)
- mockDataProvider.shouldReturnHashPrefixes(set: initialHashPrefixes)
-
- // Access the lazy-loaded properties to trigger loading
- let loadedFilterSet = dataStore.filterSet
- let loadedHashPrefixes = dataStore.hashPrefixes
-
- // Validate loaded data matches initial data
- XCTAssertEqual(loadedFilterSet, initialFilterSet)
- XCTAssertEqual(loadedHashPrefixes, initialHashPrefixes)
-
- // Update in-memory data
- let updatedFilterSet = Set([Filter(hashValue: "updated", regex: "updated")])
- let updatedHashPrefixes = Set(["updatedPrefix"])
- dataStore.saveFilterSet(set: updatedFilterSet)
- dataStore.saveHashPrefixes(set: updatedHashPrefixes)
-
- // Access lazy-loaded properties again
- let reloadedFilterSet = dataStore.filterSet
- let reloadedHashPrefixes = dataStore.hashPrefixes
-
- // Validate reloaded data matches updated data
- XCTAssertEqual(reloadedFilterSet, updatedFilterSet)
- XCTAssertEqual(reloadedHashPrefixes, updatedHashPrefixes)
-
- // Validate on-disk data is also updated
- let storedFilterSetData = fileStorageManager.read(from: "filterSet.json")
- let storedHashPrefixesData = fileStorageManager.read(from: "hashPrefixes.json")
-
- let decoder = JSONDecoder()
- if let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!),
- let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!) {
-
- XCTAssertEqual(storedFilterSet, updatedFilterSet)
- XCTAssertEqual(storedHashPrefixes, updatedHashPrefixes)
- } else {
- XCTFail("Failed to decode stored PhishingDetection data after update")
- }
- }
-
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift b/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift
deleted file mode 100644
index 6fec6c134..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectionUpdateManagerTests.swift
+++ /dev/null
@@ -1,155 +0,0 @@
-//
-// PhishingDetectionUpdateManagerTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-import Foundation
-
-import XCTest
-@testable import PhishingDetection
-
-class PhishingDetectionUpdateManagerTests: XCTestCase {
- var updateManager: PhishingDetectionUpdateManager!
- var dataStore: PhishingDetectionDataSaving!
- var mockClient: MockPhishingDetectionClient!
-
- override func setUp() async throws {
- try await super.setUp()
- mockClient = MockPhishingDetectionClient()
- dataStore = MockPhishingDetectionDataStore()
- updateManager = PhishingDetectionUpdateManager(client: mockClient, dataStore: dataStore)
- dataStore.saveRevision(0)
- await updateManager.updateFilterSet()
- await updateManager.updateHashPrefixes()
- }
-
- override func tearDown() {
- updateManager = nil
- dataStore = nil
- mockClient = nil
- super.tearDown()
- }
-
- func testUpdateHashPrefixes() async {
- await updateManager.updateHashPrefixes()
- XCTAssertFalse(dataStore.hashPrefixes.isEmpty, "Hash prefixes should not be empty after update.")
- XCTAssertEqual(dataStore.hashPrefixes, [
- "aa00bb11",
- "bb00cc11",
- "cc00dd11",
- "dd00ee11",
- "a379a6f6"
- ])
- }
-
- func testUpdateFilterSet() async {
- await updateManager.updateFilterSet()
- XCTAssertEqual(dataStore.filterSet, [
- Filter(hashValue: "testhash1", regex: ".*example.*"),
- Filter(hashValue: "testhash2", regex: ".*test.*")
- ])
- }
-
- func testRevision1AddsAndDeletesData() async {
- let expectedFilterSet: Set = [
- Filter(hashValue: "testhash2", regex: ".*test.*"),
- Filter(hashValue: "testhash3", regex: ".*test.*")
- ]
- let expectedHashPrefixes: Set = [
- "aa00bb11",
- "bb00cc11",
- "a379a6f6",
- "93e2435e"
- ]
-
- // Save revision and update the filter set and hash prefixes
- dataStore.saveRevision(1)
- await updateManager.updateFilterSet()
- await updateManager.updateHashPrefixes()
-
- XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.")
- XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.")
- }
-
- func testRevision2AddsAndDeletesData() async {
- let expectedFilterSet: Set = [
- Filter(hashValue: "testhash4", regex: ".*test.*"),
- Filter(hashValue: "testhash1", regex: ".*example.*")
- ]
- let expectedHashPrefixes: Set = [
- "aa00bb11",
- "a379a6f6",
- "c0be0d0a6",
- "dd00ee11",
- "cc00dd11"
- ]
-
- // Save revision and update the filter set and hash prefixes
- dataStore.saveRevision(2)
- await updateManager.updateFilterSet()
- await updateManager.updateHashPrefixes()
-
- XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.")
- XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.")
- }
-
- func testRevision3AddsAndDeletesNothing() async {
- let expectedFilterSet = dataStore.filterSet
- let expectedHashPrefixes = dataStore.hashPrefixes
-
- // Save revision and update the filter set and hash prefixes
- dataStore.saveRevision(3)
- await updateManager.updateFilterSet()
- await updateManager.updateHashPrefixes()
-
- XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.")
- XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.")
- }
-
- func testRevision4AddsAndDeletesData() async {
- let expectedFilterSet: Set = [
- Filter(hashValue: "testhash2", regex: ".*test.*"),
- Filter(hashValue: "testhash1", regex: ".*example.*"),
- Filter(hashValue: "testhash5", regex: ".*test.*")
- ]
- let expectedHashPrefixes: Set = [
- "a379a6f6",
- "dd00ee11",
- "cc00dd11",
- "bb00cc11"
- ]
-
- // Save revision and update the filter set and hash prefixes
- dataStore.saveRevision(4)
- await updateManager.updateFilterSet()
- await updateManager.updateHashPrefixes()
-
- XCTAssertEqual(dataStore.filterSet, expectedFilterSet, "Filter set should match the expected set after update.")
- XCTAssertEqual(dataStore.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.")
- }
-}
-
-class MockPhishingFileStorageManager: FileStorageManager {
- private var data: [String: Data] = [:]
-
- func write(data: Data, to filename: String) {
- self.data[filename] = data
- }
-
- func read(from filename: String) -> Data? {
- return data[filename]
- }
-}
diff --git a/Tests/PhishingDetectionTests/PhishingDetectorTests.swift b/Tests/PhishingDetectionTests/PhishingDetectorTests.swift
deleted file mode 100644
index d2ef4a02e..000000000
--- a/Tests/PhishingDetectionTests/PhishingDetectorTests.swift
+++ /dev/null
@@ -1,104 +0,0 @@
-//
-// PhishingDetectorTests.swift
-//
-// Copyright © 2024 DuckDuckGo. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-import Foundation
-import XCTest
-@testable import PhishingDetection
-
-class IsMaliciousTests: XCTestCase {
-
- private var mockAPIClient: MockPhishingDetectionClient!
- private var mockDataStore: MockPhishingDetectionDataStore!
- private var mockEventMapping: MockEventMapping!
- private var detector: PhishingDetector!
-
- override func setUp() {
- super.setUp()
- mockAPIClient = MockPhishingDetectionClient()
- mockDataStore = MockPhishingDetectionDataStore()
- mockEventMapping = MockEventMapping()
- detector = PhishingDetector(apiClient: mockAPIClient, dataStore: mockDataStore, eventMapping: mockEventMapping)
- }
-
- override func tearDown() {
- mockAPIClient = nil
- mockDataStore = nil
- mockEventMapping = nil
- detector = nil
- super.tearDown()
- }
-
- func testIsMaliciousWithLocalFilterHit() async {
- let filter = Filter(hashValue: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*")
- mockDataStore.filterSet = Set([filter])
- mockDataStore.hashPrefixes = Set(["255a8a79"])
-
- let url = URL(string: "https://malicious.com/")!
-
- let result = await detector.isMalicious(url: url)
-
- XCTAssertTrue(result)
- }
-
- func testIsMaliciousWithApiMatch() async {
- mockDataStore.filterSet = Set()
- mockDataStore.hashPrefixes = ["a379a6f6"]
-
- let url = URL(string: "https://example.com/mal")!
-
- let result = await detector.isMalicious(url: url)
-
- XCTAssertTrue(result)
- }
-
- func testIsMaliciousWithHashPrefixMatch() async {
- let filter = Filter(hashValue: "notamatch", regex: ".*malicious.*")
- mockDataStore.filterSet = [filter]
- mockDataStore.hashPrefixes = ["4c64eb24"] // matches safe.com
-
- let url = URL(string: "https://safe.com")!
-
- let result = await detector.isMalicious(url: url)
-
- XCTAssertFalse(result)
- }
-
- func testIsMaliciousWithFullHashMatch() async {
- // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b
- let filter = Filter(hashValue: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI")
- mockDataStore.filterSet = [filter]
- mockDataStore.hashPrefixes = ["4c64eb24"]
-
- let url = URL(string: "https://safe.com")!
-
- let result = await detector.isMalicious(url: url)
-
- XCTAssertFalse(result)
- }
-
- func testIsMaliciousWithNoHashPrefixMatch() async {
- let filter = Filter(hashValue: "testHash", regex: ".*malicious.*")
- mockDataStore.filterSet = [filter]
- mockDataStore.hashPrefixes = ["testPrefix"]
-
- let url = URL(string: "https://safe.com")!
-
- let result = await detector.isMalicious(url: url)
-
- XCTAssertFalse(result)
- }
-}
diff --git a/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift b/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift
index 4c03464fa..867e7b888 100644
--- a/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift
+++ b/Tests/PrivacyDashboardTests/PrivacyDashboardControllerTests.swift
@@ -260,13 +260,13 @@ final class PrivacyDashboardControllerTests: XCTestCase {
func testWhenIsPhishingSetThenJavaScriptEvaluatedWithCorrectString() {
let expectation = XCTestExpectation()
- let privacyInfo = PrivacyInfo(url: URL(string: "someurl.com")!, parentEntity: nil, protectionStatus: .init(unprotectedTemporary: false, enabledFeatures: [], allowlisted: true, denylisted: true), isPhishing: false)
+ let privacyInfo = PrivacyInfo(url: URL(string: "someurl.com")!, parentEntity: nil, protectionStatus: .init(unprotectedTemporary: false, enabledFeatures: [], allowlisted: true, denylisted: true), malicousSiteThreatKind: .none)
makePrivacyDashboardController(entryPoint: .dashboard, privacyInfo: privacyInfo)
let config = WKWebViewConfiguration()
let mockWebView = MockWebView(frame: .zero, configuration: config, expectation: expectation)
privacyDashboardController.webView = mockWebView
- privacyDashboardController.privacyInfo!.isPhishing = true
+ privacyDashboardController.privacyInfo!.malicousSiteThreatKind = .phishing
wait(for: [expectation], timeout: 100)
XCTAssertEqual(mockWebView.capturedJavaScriptString, "window.onChangePhishingStatus({\"phishingStatus\":true})")
diff --git a/Tests/PrivacyStatsTests/CurrentPackTests.swift b/Tests/PrivacyStatsTests/CurrentPackTests.swift
new file mode 100644
index 000000000..f22ed322d
--- /dev/null
+++ b/Tests/PrivacyStatsTests/CurrentPackTests.swift
@@ -0,0 +1,121 @@
+//
+// CurrentPackTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Combine
+import XCTest
+@testable import PrivacyStats
+
+final class CurrentPackTests: XCTestCase {
+ var currentPack: CurrentPack!
+
+ override func setUp() async throws {
+ currentPack = CurrentPack(pack: .init(timestamp: Date.currentPrivacyStatsPackTimestamp), commitDebounce: 10_000_000)
+ }
+
+ func testThatRecordBlockedTrackerUpdatesThePack() async {
+ await currentPack.recordBlockedTracker("A")
+ let companyA = await currentPack.pack.trackers["A"]
+ XCTAssertEqual(companyA, 1)
+ }
+
+ func testThatRecordBlockedTrackerTriggersCommitChangesEvent() async throws {
+ let packs = try await waitForCommitChangesEvents(for: 100_000_000) {
+ await currentPack.recordBlockedTracker("A")
+ }
+
+ let companyA = await currentPack.pack.trackers["A"]
+ XCTAssertEqual(companyA, 1)
+ XCTAssertEqual(packs.first?.trackers["A"], 1)
+ }
+
+ func testThatMultipleCallsToRecordBlockedTrackerOnlyTriggerOneCommitChangesEvent() async throws {
+ let packs = try await waitForCommitChangesEvents(for: 1000_000_000) {
+ await currentPack.recordBlockedTracker("A")
+ await currentPack.recordBlockedTracker("A")
+ await currentPack.recordBlockedTracker("A")
+ await currentPack.recordBlockedTracker("A")
+ await currentPack.recordBlockedTracker("A")
+ }
+
+ XCTAssertEqual(packs.count, 1)
+ XCTAssertEqual(packs.first?.trackers["A"], 5)
+ }
+
+ func testThatRecordBlockedTrackerCalledConcurrentlyForTheSameCompanyStoresAllCalls() async {
+ await withTaskGroup(of: Void.self) { group in
+ (0..<1000).forEach { _ in
+ group.addTask {
+ await self.currentPack.recordBlockedTracker("A")
+ }
+ }
+ }
+ let companyA = await currentPack.pack.trackers["A"]
+ XCTAssertEqual(companyA, 1000)
+ }
+
+ func testWhenCurrentPackIsOldThenRecordBlockedTrackerSendsCommitEventAndCreatesNewPack() async throws {
+ let oldTimestamp = Date.currentPrivacyStatsPackTimestamp.daysAgo(1)
+ let pack = PrivacyStatsPack(
+ timestamp: oldTimestamp,
+ trackers: ["A": 100, "B": 50, "C": 400]
+ )
+ currentPack = CurrentPack(pack: pack, commitDebounce: 10_000_000)
+
+ let packs = try await waitForCommitChangesEvents(for: 100_000_000) {
+ await currentPack.recordBlockedTracker("A")
+ }
+
+ XCTAssertEqual(packs.count, 2)
+ let oldPack = try XCTUnwrap(packs.first)
+ XCTAssertEqual(oldPack, pack)
+ let newPack = try XCTUnwrap(packs.last)
+ XCTAssertEqual(newPack, PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp, trackers: ["A": 1]))
+ }
+
+ func testThatResetPackClearsAllRecordedTrackersAndSetsCurrentTimestamp() async {
+ let oldTimestamp = Date.currentPrivacyStatsPackTimestamp.daysAgo(1)
+ let pack = PrivacyStatsPack(
+ timestamp: oldTimestamp,
+ trackers: ["A": 100, "B": 50, "C": 400]
+ )
+ currentPack = CurrentPack(pack: pack, commitDebounce: 10_000_000)
+
+ await currentPack.resetPack()
+
+ let packAfterReset = await currentPack.pack
+ XCTAssertEqual(packAfterReset, PrivacyStatsPack(timestamp: Date.currentPrivacyStatsPackTimestamp, trackers: [:]))
+ }
+
+ // MARK: - Helpers
+
+ /**
+ * Sets up Combine subscription, then calls the provided block and then waits
+ * for the specific time before cancelling the subscription.
+ * Returns an array of values passed in the published events.
+ */
+ func waitForCommitChangesEvents(for nanoseconds: UInt64, _ block: () async -> Void) async throws -> [PrivacyStatsPack] {
+ var packs: [PrivacyStatsPack] = []
+ let cancellable = currentPack.commitChangesPublisher.sink { packs.append($0) }
+
+ await block()
+
+ try await Task.sleep(nanoseconds: nanoseconds)
+ cancellable.cancel()
+ return packs
+ }
+}
diff --git a/Tests/PrivacyStatsTests/PrivacyStatsTests.swift b/Tests/PrivacyStatsTests/PrivacyStatsTests.swift
new file mode 100644
index 000000000..fa05d8178
--- /dev/null
+++ b/Tests/PrivacyStatsTests/PrivacyStatsTests.swift
@@ -0,0 +1,317 @@
+//
+// PrivacyStatsTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Combine
+import Persistence
+import TrackerRadarKit
+import XCTest
+@testable import PrivacyStats
+
+final class PrivacyStatsTests: XCTestCase {
+ var databaseProvider: TestPrivacyStatsDatabaseProvider!
+ var privacyStats: PrivacyStats!
+
+ override func setUp() async throws {
+ databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description())
+ privacyStats = PrivacyStats(databaseProvider: databaseProvider)
+ }
+
+ override func tearDown() async throws {
+ databaseProvider.tearDownDatabase()
+ }
+
+ // MARK: - initializer
+
+ func testThatOutdatedTrackerStatsAreDeletedUponInitialization() async throws {
+ try databaseProvider.addObjects { context in
+ let date = Date()
+
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "A", count: 7, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 100, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(8), companyName: "A", count: 100, context: context)
+ ]
+ }
+
+ // recreate database provider with existing location so that the existing database is persisted in the initializer
+ databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description(), location: databaseProvider.location)
+ privacyStats = PrivacyStats(databaseProvider: databaseProvider)
+
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 10])
+
+ let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ do {
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertEqual(Set(allObjects.map(\.count)), [1, 2, 7])
+ } catch {
+ XCTFail("Context fetch should not fail")
+ }
+ }
+ }
+
+ // MARK: - fetchPrivacyStats
+
+ func testThatPrivacyStatsAreFetched() async throws {
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, [:])
+ }
+
+ func testThatFetchPrivacyStatsReturnsAllCompanies() async throws {
+ try databaseProvider.addObjects { context in
+ [
+ DailyBlockedTrackersEntity.make(companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(companyName: "B", count: 5, context: context),
+ DailyBlockedTrackersEntity.make(companyName: "C", count: 13, context: context),
+ DailyBlockedTrackersEntity.make(companyName: "D", count: 42, context: context)
+ ]
+ }
+
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 10, "B": 5, "C": 13, "D": 42])
+ }
+
+ func testThatFetchPrivacyStatsReturnsSumOfCompanyEntriesForPast7Days() async throws {
+ try databaseProvider.addObjects { context in
+ let date = Date()
+
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "A", count: 3, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(3), companyName: "A", count: 4, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "A", count: 5, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(5), companyName: "A", count: 6, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "A", count: 7, context: context)
+ ]
+ }
+
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 28])
+ }
+
+ func testThatFetchPrivacyStatsDiscardsEntriesOlderThan7Days() async throws {
+ try databaseProvider.addObjects { context in
+ let date = Date()
+
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(500), companyName: "A", count: 10, context: context),
+ ]
+ }
+
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 3])
+ }
+
+ // MARK: - recordBlockedTracker
+
+ func testThatCallingRecordBlockedTrackerCausesDatabaseSaveAfterDelay() async throws {
+ await privacyStats.recordBlockedTracker("A")
+
+ var stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, [:])
+
+ try await Task.sleep(nanoseconds: 1_500_000_000)
+
+ stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 1])
+ }
+
+ func testThatStatsUpdatePublisherIsCalledAfterDatabaseSave() async throws {
+ await privacyStats.recordBlockedTracker("A")
+
+ await waitForStatsUpdateEvent()
+
+ var stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 1])
+
+ await privacyStats.recordBlockedTracker("B")
+
+ await waitForStatsUpdateEvent()
+
+ stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 1, "B": 1])
+ }
+
+ func testWhenMultipleTrackersAreReportedInQuickSuccessionThenOnlyOneStatsUpdateEventIsReported() async throws {
+ await withTaskGroup(of: Void.self) { group in
+ (0..<5).forEach { _ in
+ group.addTask {
+ await self.privacyStats.recordBlockedTracker("A")
+ }
+ }
+ (0..<10).forEach { _ in
+ group.addTask {
+ await self.privacyStats.recordBlockedTracker("B")
+ }
+ }
+ (0..<3).forEach { _ in
+ group.addTask {
+ await self.privacyStats.recordBlockedTracker("C")
+ }
+ }
+ }
+
+ // We have limited testing possibilities here, so let's just await the first stats update event
+ // and verify that all trackers are reported by privacy stats.
+ await waitForStatsUpdateEvent()
+
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 5, "B": 10, "C": 3])
+ }
+
+ func testThatCallingRecordBlockedTrackerWithNextDayTimestampCausesDeletingOldEntriesFromDatabase() async throws {
+ try databaseProvider.addObjects { context in
+ let date = Date()
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 100, context: context),
+ ]
+ }
+
+ // recreate database provider with existing location so that the existing database is persisted in the initializer
+ databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description(), location: databaseProvider.location)
+ privacyStats = PrivacyStats(databaseProvider: databaseProvider)
+
+ await privacyStats.recordBlockedTracker("A")
+
+ // No waiting here because the first commit event will be sent immediately from the actor when pack's timestamp changes.
+ // We aren't testing the debounced commit in this test case.
+
+ var stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 2])
+
+ let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ do {
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertEqual(Set(allObjects.map(\.count)), [2])
+ } catch {
+ XCTFail("Context fetch should not fail")
+ }
+ }
+
+ await waitForStatsUpdateEvent()
+ stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 3])
+ }
+
+ // MARK: - clearPrivacyStats
+
+ func testThatClearPrivacyStatsTriggersUpdatesPublisher() async throws {
+ try await waitForStatsUpdateEvents(for: 1, count: 1) {
+ await privacyStats.clearPrivacyStats()
+ }
+ }
+
+ func testWhenClearPrivacyStatsIsCalledThenFetchPrivacyStatsIsEmpty() async throws {
+ try databaseProvider.addObjects { context in
+ let date = Date()
+
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "A", count: 10, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(500), companyName: "A", count: 10, context: context),
+ ]
+ }
+
+ var stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertFalse(stats.isEmpty)
+
+ await privacyStats.clearPrivacyStats()
+
+ stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertTrue(stats.isEmpty)
+
+ let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ do {
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertTrue(allObjects.isEmpty)
+ } catch {
+ XCTFail("fetch failed: \(error)")
+ }
+ }
+ }
+
+ // MARK: - handleAppTermination
+
+ func testThatHandleAppTerminationSavesCurrentPack() async throws {
+ let context = databaseProvider.database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertTrue(allObjects.isEmpty)
+ } catch {
+ XCTFail("fetch failed: \(error)")
+ }
+ }
+ await privacyStats.recordBlockedTracker("A")
+ await privacyStats.handleAppTermination()
+
+ context.performAndWait {
+ do {
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertEqual(allObjects.count, 1)
+ } catch {
+ XCTFail("fetch failed: \(error)")
+ }
+ }
+
+ await waitForStatsUpdateEvent()
+ let stats = await privacyStats.fetchPrivacyStats()
+ XCTAssertEqual(stats, ["A": 1])
+ }
+
+ // MARK: - Helpers
+
+ func waitForStatsUpdateEvent(file: StaticString = #file, line: UInt = #line) async {
+ let expectation = self.expectation(description: "statsUpdate")
+ let cancellable = privacyStats.statsUpdatePublisher.sink { expectation.fulfill() }
+ await fulfillment(of: [expectation], timeout: 2)
+ cancellable.cancel()
+ }
+
+ /**
+ * Sets up an expectation with the fulfillment count specified by `count` parameter,
+ * then sets up Combine subscription, then calls the provided block and waits
+ * for time specified by `duration` before cancelling the subscription.
+ */
+ func waitForStatsUpdateEvents(for duration: TimeInterval, count: Int, _ block: () async -> Void) async throws {
+ let expectation = self.expectation(description: "statsUpdate")
+ expectation.expectedFulfillmentCount = count
+ let cancellable = privacyStats.statsUpdatePublisher.sink { expectation.fulfill() }
+
+ await block()
+
+ await fulfillment(of: [expectation], timeout: duration)
+ cancellable.cancel()
+ }
+}
diff --git a/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift b/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift
new file mode 100644
index 000000000..ec0e7e606
--- /dev/null
+++ b/Tests/PrivacyStatsTests/PrivacyStatsUtilsTests.swift
@@ -0,0 +1,360 @@
+//
+// PrivacyStatsUtilsTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Persistence
+import XCTest
+@testable import PrivacyStats
+
+final class PrivacyStatsUtilsTests: XCTestCase {
+ var databaseProvider: TestPrivacyStatsDatabaseProvider!
+ var database: CoreDataDatabase!
+
+ override func setUp() async throws {
+ databaseProvider = TestPrivacyStatsDatabaseProvider(databaseName: type(of: self).description())
+ databaseProvider.initializeDatabase()
+ database = databaseProvider.database
+ }
+
+ override func tearDown() async throws {
+ databaseProvider.tearDownDatabase()
+ }
+
+ // MARK: - fetchOrInsertCurrentStats
+
+ func testWhenThereAreNoObjectsForCompaniesThenFetchOrInsertCurrentStatsInsertsNewObjects() {
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ let currentPackTimestamp = Date.currentPrivacyStatsPackTimestamp
+ let companyNames: Set = ["A", "B", "C", "D"]
+
+ var returnedEntities: [DailyBlockedTrackersEntity] = []
+ do {
+ returnedEntities = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: companyNames, in: context)
+ } catch {
+ XCTFail("Should not throw")
+ }
+
+ let insertedEntities = context.insertedObjects.compactMap { $0 as? DailyBlockedTrackersEntity }
+
+ XCTAssertEqual(returnedEntities.count, 4)
+ XCTAssertEqual(insertedEntities.count, 4)
+ XCTAssertEqual(Set(insertedEntities.map(\.companyName)), companyNames)
+ XCTAssertEqual(Set(insertedEntities.map(\.companyName)), Set(returnedEntities.map(\.companyName)))
+
+ // All inserted entries have the same timestamp
+ XCTAssertEqual(Set(insertedEntities.map(\.timestamp)), [currentPackTimestamp])
+
+ // All inserted entries have the count of 0
+ XCTAssertEqual(Set(insertedEntities.map(\.count)), [0])
+ }
+ }
+
+ func testWhenThereAreExistingObjectsForCompaniesThenFetchOrInsertCurrentStatsReturnsThem() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 123, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 4567, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ let companyNames: Set = ["A", "B", "C", "D"]
+
+ var returnedEntities: [DailyBlockedTrackersEntity] = []
+ do {
+ returnedEntities = try PrivacyStatsUtils.fetchOrInsertCurrentStats(for: companyNames, in: context)
+ } catch {
+ XCTFail("Should not throw")
+ }
+
+ let insertedEntities = context.insertedObjects.compactMap { $0 as? DailyBlockedTrackersEntity }
+
+ XCTAssertEqual(returnedEntities.count, 4)
+ XCTAssertEqual(insertedEntities.count, 2)
+ XCTAssertEqual(Set(returnedEntities.map(\.companyName)), companyNames)
+ XCTAssertEqual(Set(insertedEntities.map(\.companyName)), ["C", "D"])
+
+ do {
+ let companyA = try XCTUnwrap(returnedEntities.first { $0.companyName == "A" })
+ let companyB = try XCTUnwrap(returnedEntities.first { $0.companyName == "B" })
+
+ XCTAssertEqual(companyA.count, 123)
+ XCTAssertEqual(companyB.count, 4567)
+ } catch {
+ XCTFail("Should find companies A and B")
+ }
+ }
+ }
+
+ // MARK: - loadCurrentDayStats
+
+ func testWhenThereAreNoObjectsInDatabaseThenLoadCurrentDayStatsIsEmpty() throws {
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context)
+ XCTAssertTrue(currentDayStats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testWhenThereAreObjectsInDatabaseForPreviousDaysThenLoadCurrentDayStatsIsEmpty() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "A", count: 123, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "B", count: 4567, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(5), companyName: "C", count: 890, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context)
+ XCTAssertTrue(currentDayStats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testThatObjectsWithZeroCountAreNotReportedByLoadCurrentDayStats() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 0, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 0, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 0, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context)
+ XCTAssertTrue(currentDayStats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testThatObjectsWithNonZeroCountAreReportedByLoadCurrentDayStats() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 150, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "B", count: 400, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 84, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "D", count: 5, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let currentDayStats = try PrivacyStatsUtils.loadCurrentDayStats(in: context)
+ XCTAssertEqual(currentDayStats, ["A": 150, "B": 400, "C": 84, "D": 5])
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ // MARK: - load7DayStats
+
+ func testWhenThereAreNoObjectsInDatabaseThenLoad7DayStatsIsEmpty() throws {
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let stats = try PrivacyStatsUtils.load7DayStats(in: context)
+ XCTAssertTrue(stats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testWhenThereAreObjectsInDatabaseFrom7DaysAgoOrMoreThenLoad7DayStatsIsEmpty() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(10), companyName: "A", count: 123, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(20), companyName: "B", count: 4567, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "C", count: 890, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let stats = try PrivacyStatsUtils.load7DayStats(in: context)
+ XCTAssertTrue(stats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testThatObjectsWithZeroCountAreNotReportedByLoad7DayStats() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 0, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "B", count: 0, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 0, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let stats = try PrivacyStatsUtils.load7DayStats(in: context)
+ XCTAssertTrue(stats.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testThatObjectsWithNonZeroCountAreReportedByLoad7DayStats() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "A", count: 150, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(1), companyName: "B", count: 400, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(2), companyName: "C", count: 84, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "D", count: 5, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ let stats = try PrivacyStatsUtils.load7DayStats(in: context)
+ XCTAssertEqual(stats, ["A": 150, "B": 400, "C": 84, "D": 5])
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ // MARK: - deleteOutdatedPacks
+
+ func testWhenDeleteOutdatedPacksIsCalledThenObjectsFrom7DaysAgoOrMoreAreDeleted() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(7), companyName: "C", count: 4, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(8), companyName: "C", count: 5, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(100), companyName: "C", count: 6, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ try PrivacyStatsUtils.deleteOutdatedPacks(in: context)
+
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertEqual(Set(allObjects.map(\.count)), [1, 2, 3])
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ func testWhenObjectsFrom7DaysAgoOrMoreAreNotPresentThenDeleteOutdatedPacksHasNoEffect() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ try PrivacyStatsUtils.deleteOutdatedPacks(in: context)
+
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertEqual(allObjects.count, 3)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+
+ // MARK: - deleteAllStats
+
+ func testThatDeleteAllStatsRemovesAllDatabaseObjects() throws {
+ let date = Date()
+
+ try databaseProvider.addObjects { context in
+ return [
+ DailyBlockedTrackersEntity.make(timestamp: date, companyName: "C", count: 1, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(4), companyName: "C", count: 2, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(6), companyName: "C", count: 3, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(60), companyName: "C", count: 3, context: context),
+ DailyBlockedTrackersEntity.make(timestamp: date.daysAgo(600), companyName: "C", count: 3, context: context)
+ ]
+ }
+
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+
+ context.performAndWait {
+ do {
+ try PrivacyStatsUtils.deleteAllStats(in: context)
+
+ let allObjects = try context.fetch(DailyBlockedTrackersEntity.fetchRequest())
+ XCTAssertTrue(allObjects.isEmpty)
+ } catch {
+ XCTFail("Should not throw")
+ }
+ }
+ }
+}
diff --git a/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift b/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift
new file mode 100644
index 000000000..2cb210f0b
--- /dev/null
+++ b/Tests/PrivacyStatsTests/TestPrivacyStatsDatabaseProvider.swift
@@ -0,0 +1,65 @@
+//
+// TestPrivacyStatsDatabaseProvider.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import Persistence
+import XCTest
+@testable import PrivacyStats
+
+final class TestPrivacyStatsDatabaseProvider: PrivacyStatsDatabaseProviding {
+ let databaseName: String
+ var database: CoreDataDatabase!
+ var location: URL!
+
+ init(databaseName: String) {
+ self.databaseName = databaseName
+ }
+
+ init(databaseName: String, location: URL) {
+ self.databaseName = databaseName
+ self.location = location
+ }
+
+ @discardableResult
+ func initializeDatabase() -> CoreDataDatabase {
+ if location == nil {
+ location = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
+ }
+ let model = CoreDataDatabase.loadModel(from: PrivacyStats.bundle, named: "PrivacyStats")!
+ database = CoreDataDatabase(name: databaseName, containerLocation: location, model: model)
+ database.loadStore()
+ return database
+ }
+
+ func tearDownDatabase() {
+ try? database.tearDown(deleteStores: true)
+ database = nil
+ try? FileManager.default.removeItem(at: location)
+ }
+
+ func addObjects(_ objects: (NSManagedObjectContext) -> [DailyBlockedTrackersEntity], file: StaticString = #file, line: UInt = #line) throws {
+ let context = database.makeContext(concurrencyType: .privateQueueConcurrencyType)
+ context.performAndWait {
+ _ = objects(context)
+ do {
+ try context.save()
+ } catch {
+ XCTFail("save failed: \(error)", file: file, line: line)
+ }
+ }
+ }
+}
diff --git a/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift
new file mode 100644
index 000000000..c6a490eae
--- /dev/null
+++ b/Tests/RemoteMessagingTests/Mappers/DefaultRemoteMessagingSurveyURLBuilderTests.swift
@@ -0,0 +1,127 @@
+//
+// DefaultRemoteMessagingSurveyURLBuilderTests.swift
+//
+// Copyright © 2024 DuckDuckGo. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+import XCTest
+import BrowserServicesKitTestsUtils
+import RemoteMessagingTestsUtils
+@testable import Subscription
+@testable import RemoteMessaging
+
+class DefaultRemoteMessagingSurveyURLBuilderTests: XCTestCase {
+
+ func testAddingATBParameter() {
+ let builder = buildRemoteMessagingSurveyURLBuilder(atb: "v456-7")
+ let baseURL = URL(string: "https://duckduckgo.com")!
+ let finalURL = builder.add(parameters: [.atb], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?atb=v456-7")
+ }
+
+ func testAddingATBVariantParameter() {
+ let builder = buildRemoteMessagingSurveyURLBuilder(variant: "test-variant")
+ let baseURL = URL(string: "https://duckduckgo.com")!
+ let finalURL = builder.add(parameters: [.atbVariant], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?var=test-variant")
+ }
+
+ func testAddingLocaleParameter() {
+ let builder = buildRemoteMessagingSurveyURLBuilder(locale: Locale(identifier: "en_NZ"))
+ let baseURL = URL(string: "https://duckduckgo.com")!
+ let finalURL = builder.add(parameters: [.locale], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?locale=en-NZ")
+ }
+
+ func testAddingPrivacyProParameters() {
+ let builder = buildRemoteMessagingSurveyURLBuilder()
+ let baseURL = URL(string: "https://duckduckgo.com")!
+ let finalURL = builder.add(parameters: [.privacyProStatus, .privacyProPlatform, .privacyProPlatform], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?ppro_status=auto_renewable&ppro_platform=apple&ppro_platform=apple")
+ }
+
+ func testAddingVPNUsageParameters() {
+ let builder = buildRemoteMessagingSurveyURLBuilder(vpnDaysSinceActivation: 10, vpnDaysSinceLastActive: 5)
+ let baseURL = URL(string: "https://duckduckgo.com")!
+ let finalURL = builder.add(parameters: [.vpnFirstUsed, .vpnLastUsed], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?vpn_first_used=10&vpn_last_used=5")
+ }
+
+ func testAddingParametersToURLThatAlreadyHasThem() {
+ let builder = buildRemoteMessagingSurveyURLBuilder(vpnDaysSinceActivation: 10, vpnDaysSinceLastActive: 5)
+ let baseURL = URL(string: "https://duckduckgo.com?param=test")!
+ let finalURL = builder.add(parameters: [.vpnFirstUsed, .vpnLastUsed], to: baseURL)
+
+ XCTAssertEqual(finalURL.absoluteString, "https://duckduckgo.com?param=test&vpn_first_used=10&vpn_last_used=5")
+ }
+
+ private func buildRemoteMessagingSurveyURLBuilder(
+ atb: String = "v123-4",
+ variant: String = "var",
+ vpnDaysSinceActivation: Int = 2,
+ vpnDaysSinceLastActive: Int = 1,
+ locale: Locale = Locale(identifier: "en_US")
+ ) -> DefaultRemoteMessagingSurveyURLBuilder {
+
+ let mockStatisticsStore = MockStatisticsStore()
+ mockStatisticsStore.atb = atb
+ mockStatisticsStore.variant = variant
+
+ let vpnActivationDateStore = MockVPNActivationDateStore(
+ daysSinceActivation: vpnDaysSinceActivation,
+ daysSinceLastActive: vpnDaysSinceLastActive
+ )
+
+ let subscription = DDGSubscription(productId: "product-id",
+ name: "product-name",
+ billingPeriod: .monthly,
+ startedAt: Date(timeIntervalSince1970: 1000),
+ expiresOrRenewsAt: Date(timeIntervalSince1970: 2000),
+ platform: .apple,
+ status: .autoRenewable)
+
+ return DefaultRemoteMessagingSurveyURLBuilder(
+ statisticsStore: mockStatisticsStore,
+ vpnActivationDateStore: vpnActivationDateStore,
+ subscription: subscription,
+ localeIdentifier: locale.identifier)
+ }
+
+}
+
+private class MockVPNActivationDateStore: VPNActivationDateProviding {
+
+ var _daysSinceActivation: Int
+ var _daysSinceLastActive: Int
+
+ init(daysSinceActivation: Int, daysSinceLastActive: Int) {
+ self._daysSinceActivation = daysSinceActivation
+ self._daysSinceLastActive = daysSinceLastActive
+ }
+
+ func daysSinceActivation() -> Int? {
+ return _daysSinceActivation
+ }
+
+ func daysSinceLastActive() -> Int? {
+ return _daysSinceLastActive
+ }
+
+}
diff --git a/Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift b/Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift
similarity index 96%
rename from Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift
rename to Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift
index fa0ebf895..4eec5c739 100644
--- a/Tests/SpecialErrorPagesTests/SpecialErrorPagesTest.swift
+++ b/Tests/SpecialErrorPagesTests/SpecialErrorPagesTests.swift
@@ -1,5 +1,5 @@
//
-// SpecialErrorPagesTest.swift
+// SpecialErrorPagesTests.swift
//
// Copyright © 2023 DuckDuckGo. All rights reserved.
//
@@ -108,7 +108,7 @@ final class SpecialErrorPageUserScriptTests: XCTestCase {
@MainActor
func test_WhenHandlerForInitialSetUpCalled_AndIsEnabledTrue_ThenRightParameterReturned() async {
// GIVEN
- let expectedData = SpecialErrorData(kind: .ssl, errorType: "some error type", domain: "someDomain")
+ let expectedData = SpecialErrorData.ssl(type: .invalid, domain: "someDomain", eTldPlus1: nil)
var encodable: Encodable?
userScript.isEnabled = true
delegate.errorData = expectedData
@@ -191,11 +191,11 @@ class CapturingSpecialErrorPageUserScriptDelegate: SpecialErrorPageUserScriptDel
var visitSiteCalled = false
var advancedInfoPresentedCalled = false
- func leaveSite() {
+ func leaveSiteAction() {
leaveSiteCalled = true
}
- func visitSite() {
+ func visitSiteAction() {
visitSiteCalled = true
}
diff --git a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift
index 5dfa14980..3fb695733 100644
--- a/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift
+++ b/Tests/SubscriptionTests/Flows/Models/SubscriptionOptionsTests.swift
@@ -23,7 +23,7 @@ import SubscriptionTestingUtilities
final class SubscriptionOptionsTests: XCTestCase {
func testEncoding() throws {
- let subscriptionOptions = SubscriptionOptions(platform: "macos",
+ let subscriptionOptions = SubscriptionOptions(platform: .macos,
options: [
SubscriptionOption(id: "1",
cost: SubscriptionOptionCost(displayPrice: "9 USD", recurrence: "monthly")),
@@ -31,9 +31,9 @@ final class SubscriptionOptionsTests: XCTestCase {
cost: SubscriptionOptionCost(displayPrice: "99 USD", recurrence: "yearly"))
],
features: [
- SubscriptionFeature(name: "vpn"),
- SubscriptionFeature(name: "personal-information-removal"),
- SubscriptionFeature(name: "identity-theft-restoration")
+ SubscriptionFeature(name: .networkProtection),
+ SubscriptionFeature(name: .dataBrokerProtection),
+ SubscriptionFeature(name: .identityTheftRestoration)
])
let jsonEncoder = JSONEncoder()
@@ -45,13 +45,13 @@ final class SubscriptionOptionsTests: XCTestCase {
{
"features" : [
{
- "name" : "vpn"
+ "name" : "Network Protection"
},
{
- "name" : "personal-information-removal"
+ "name" : "Data Broker Protection"
},
{
- "name" : "identity-theft-restoration"
+ "name" : "Identity Theft Restoration"
}
],
"options" : [
@@ -87,12 +87,12 @@ final class SubscriptionOptionsTests: XCTestCase {
}
func testSubscriptionFeatureEncoding() throws {
- let subscriptionFeature = SubscriptionFeature(name: "identity-theft-restoration")
+ let subscriptionFeature = SubscriptionFeature(name: .identityTheftRestoration)
let data = try? JSONEncoder().encode(subscriptionFeature)
let subscriptionFeatureString = String(data: data!, encoding: .utf8)!
- XCTAssertEqual(subscriptionFeatureString, "{\"name\":\"identity-theft-restoration\"}")
+ XCTAssertEqual(subscriptionFeatureString, "{\"name\":\"Identity Theft Restoration\"}")
}
func testEmptySubscriptionOptions() throws {
@@ -105,8 +105,8 @@ final class SubscriptionOptionsTests: XCTestCase {
platform = .macos
#endif
- XCTAssertEqual(empty.platform, platform.rawValue)
+ XCTAssertEqual(empty.platform, platform)
XCTAssertTrue(empty.options.isEmpty)
- XCTAssertEqual(empty.features.count, SubscriptionFeatureName.allCases.count)
+ XCTAssertEqual(empty.features.count, 3)
}
}
diff --git a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift
index 1752e3d15..e397805db 100644
--- a/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift
+++ b/Tests/SubscriptionTests/Flows/StripePurchaseFlowTests.swift
@@ -67,12 +67,14 @@ final class StripePurchaseFlowTests: XCTestCase {
// Then
switch result {
case .success(let success):
- XCTAssertEqual(success.platform, SubscriptionPlatformName.stripe.rawValue)
+ XCTAssertEqual(success.platform, SubscriptionPlatformName.stripe)
XCTAssertEqual(success.options.count, SubscriptionMockFactory.productsItems.count)
- XCTAssertEqual(success.features.count, SubscriptionFeatureName.allCases.count)
+ XCTAssertEqual(success.features.count, 3)
+ let allFeatures = [Entitlement.ProductName.networkProtection, Entitlement.ProductName.dataBrokerProtection, Entitlement.ProductName.identityTheftRestoration]
let allNames = success.features.compactMap({ feature in feature.name})
- for name in SubscriptionFeatureName.allCases {
- XCTAssertTrue(allNames.contains(name.rawValue))
+
+ for feature in allFeatures {
+ XCTAssertTrue(allNames.contains(feature))
}
case .failure(let error):
XCTFail("Unexpected failure: \(error)")
diff --git a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift
index 9054194da..26ce6d89c 100644
--- a/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift
+++ b/Tests/SubscriptionTests/Managers/SubscriptionManagerTests.swift
@@ -34,7 +34,9 @@ final class SubscriptionManagerTests: XCTestCase {
var accountManager: AccountManagerMock!
var subscriptionService: SubscriptionEndpointServiceMock!
var authService: AuthEndpointServiceMock!
+ var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock!
var subscriptionEnvironment: SubscriptionEnvironment!
+ var subscriptionFeatureFlagger: FeatureFlaggerMapping!
var subscriptionManager: SubscriptionManager!
@@ -43,14 +45,18 @@ final class SubscriptionManagerTests: XCTestCase {
accountManager = AccountManagerMock()
subscriptionService = SubscriptionEndpointServiceMock()
authService = AuthEndpointServiceMock()
+ subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock()
subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production,
purchasePlatform: .appStore)
+ subscriptionFeatureFlagger = FeatureFlaggerMapping(mapping: { $0.defaultState })
subscriptionManager = DefaultSubscriptionManager(storePurchaseManager: storePurchaseManager,
accountManager: accountManager,
subscriptionEndpointService: subscriptionService,
authEndpointService: authService,
- subscriptionEnvironment: subscriptionEnvironment)
+ subscriptionFeatureMappingCache: subscriptionFeatureMappingCache,
+ subscriptionEnvironment: subscriptionEnvironment,
+ subscriptionFeatureFlagger: subscriptionFeatureFlagger)
}
@@ -202,7 +208,9 @@ final class SubscriptionManagerTests: XCTestCase {
accountManager: accountManager,
subscriptionEndpointService: subscriptionService,
authEndpointService: authService,
- subscriptionEnvironment: productionEnvironment)
+ subscriptionFeatureMappingCache: subscriptionFeatureMappingCache,
+ subscriptionEnvironment: productionEnvironment,
+ subscriptionFeatureFlagger: subscriptionFeatureFlagger)
// When
let productionPurchaseURL = productionSubscriptionManager.url(for: .purchase)
@@ -219,7 +227,9 @@ final class SubscriptionManagerTests: XCTestCase {
accountManager: accountManager,
subscriptionEndpointService: subscriptionService,
authEndpointService: authService,
- subscriptionEnvironment: stagingEnvironment)
+ subscriptionFeatureMappingCache: subscriptionFeatureMappingCache,
+ subscriptionEnvironment: stagingEnvironment,
+ subscriptionFeatureFlagger: subscriptionFeatureFlagger)
// When
let stagingPurchaseURL = stagingSubscriptionManager.url(for: .purchase)
diff --git a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift
index 07fb12bd4..2a6a9d3d8 100644
--- a/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift
+++ b/Tests/SubscriptionTests/SubscriptionCookie/SubscriptionCookieManagerTests.swift
@@ -33,6 +33,7 @@ final class SubscriptionCookieManagerTests: XCTestCase {
var authService: AuthEndpointServiceMock!
var storePurchaseManager: StorePurchaseManagerMock!
var subscriptionEnvironment: SubscriptionEnvironment!
+ var subscriptionFeatureMappingCache: SubscriptionFeatureMappingCacheMock!
var subscriptionManager: SubscriptionManagerMock!
var cookieStore: HTTPCookieStore!
@@ -45,13 +46,15 @@ final class SubscriptionCookieManagerTests: XCTestCase {
storePurchaseManager = StorePurchaseManagerMock()
subscriptionEnvironment = SubscriptionEnvironment(serviceEnvironment: .production,
purchasePlatform: .appStore)
+ subscriptionFeatureMappingCache = SubscriptionFeatureMappingCacheMock()
subscriptionManager = SubscriptionManagerMock(accountManager: accountManager,
subscriptionEndpointService: subscriptionService,
authEndpointService: authService,
storePurchaseManager: storePurchaseManager,
currentEnvironment: subscriptionEnvironment,
- canPurchase: true)
+ canPurchase: true,
+ subscriptionFeatureMappingCache: subscriptionFeatureMappingCache)
cookieStore = MockHTTPCookieStore()
subscriptionCookieManager = SubscriptionCookieManager(subscriptionManager: subscriptionManager,