diff --git a/Package.resolved b/Package.resolved index 4461ccf24..2b5b815cc 100644 --- a/Package.resolved +++ b/Package.resolved @@ -63,6 +63,24 @@ "version" : "3.0.0" } }, + { + "identity" : "swift-clocks", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-clocks.git", + "state" : { + "revision" : "b9b24b69e2adda099a1fa381cda1eeec272d5b53", + "version" : "1.0.5" + } + }, + { + "identity" : "swift-concurrency-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-concurrency-extras", + "state" : { + "revision" : "163409ef7dae9d960b87f34b51587b6609a76c1f", + "version" : "1.3.0" + } + }, { "identity" : "swifter", "kind" : "remoteSourceControl", @@ -89,6 +107,15 @@ "revision" : "5de0a610a7927b638a5fd463a53032c9934a2c3b", "version" : "3.0.0" } + }, + { + "identity" : "xctest-dynamic-overlay", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/xctest-dynamic-overlay", + "state" : { + "revision" : "a3f634d1a409c7979cabc0a71b3f26ffa9fc8af1", + "version" : "1.4.3" + } } ], "version" : 2 diff --git a/Package.swift b/Package.swift index c074ced7e..d8b3e2d0c 100644 --- a/Package.swift +++ b/Package.swift @@ -57,7 +57,8 @@ let package = Package( .package(url: "https://github.com/duckduckgo/privacy-dashboard", exact: "7.2.0"), .package(url: "https://github.com/httpswift/swifter.git", exact: "1.5.0"), .package(url: "https://github.com/duckduckgo/bloom_cpp.git", exact: "3.0.0"), - .package(url: "https://github.com/1024jp/GzipSwift.git", exact: "6.0.1") + .package(url: "https://github.com/1024jp/GzipSwift.git", exact: "6.0.1"), + .package(url: "https://github.com/pointfreeco/swift-clocks.git", exact: "1.0.5"), ], targets: [ .target( @@ -644,7 +645,8 @@ let package = Package( .testTarget( name: "DuckPlayerTests", dependencies: [ - "DuckPlayer" + "DuckPlayer", + "BrowserServicesKitTestsUtils", ] ), @@ -653,6 +655,7 @@ let package = Package( dependencies: [ "TestUtils", 
"MaliciousSiteProtection", + .product(name: "Clocks", package: "swift-clocks"), ], resources: [ .copy("Resources/phishingHashPrefixes.json"), diff --git a/Sources/Common/Concurrency/TaskExtension.swift b/Sources/Common/Concurrency/TaskExtension.swift index a407e4406..65d974a36 100644 --- a/Sources/Common/Concurrency/TaskExtension.swift +++ b/Sources/Common/Concurrency/TaskExtension.swift @@ -18,51 +18,72 @@ import Foundation +public struct Sleeper { + + public static let `default` = Sleeper(sleep: { + try await Task.sleep(interval: $0) + }) + + private let sleep: (TimeInterval) async throws -> Void + + public init(sleep: @escaping (TimeInterval) async throws -> Void) { + self.sleep = sleep + } + + @available(macOS 13.0, iOS 16.0, *) + public init(clock: any Clock) { + self.sleep = { interval in + try await clock.sleep(for: .nanoseconds(UInt64(interval * Double(NSEC_PER_SEC)))) + } + } + + public func sleep(for interval: TimeInterval) async throws { + try await sleep(interval) + } + +} + +public func performPeriodicJob(withDelay delay: TimeInterval? = nil, + interval: TimeInterval, + sleeper: Sleeper = .default, + operation: @escaping @Sendable () async throws -> Void, + cancellationHandler: (@Sendable () async -> Void)? = nil) async throws -> Never { + + do { + if let delay { + try await sleeper.sleep(for: delay) + } + + repeat { + try await operation() + + try await sleeper.sleep(for: interval) + } while true + } catch let error as CancellationError { + await cancellationHandler?() + throw error + } +} + public extension Task where Success == Never, Failure == Error { static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async -> Void, cancellationHandler: (@Sendable () async -> Void)? 
= nil) -> Task { - Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } - } + return periodic(delay: delay, interval: interval, sleeper: sleeper, operation: { await operation() } as @Sendable () async throws -> Void, cancellationHandler: cancellationHandler) } static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async throws -> Void, cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task { Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - try await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } + try await performPeriodicJob(withDelay: delay, interval: interval, sleeper: sleeper, operation: operation, cancellationHandler: cancellationHandler) } } } diff --git a/Sources/Common/Extensions/HashExtension.swift b/Sources/Common/Extensions/HashExtension.swift index b6752cf57..13095cf63 100644 --- a/Sources/Common/Extensions/HashExtension.swift +++ b/Sources/Common/Extensions/HashExtension.swift @@ -42,8 +42,13 @@ extension Data { extension String { public var sha1: String { - let dataBytes = data(using: .utf8)! 
- return dataBytes.sha1 + let result = utf8data.sha1 + return result + } + + public var sha256: String { + let result = utf8data.sha256 + return result } } diff --git a/Sources/Common/Extensions/StringExtension.swift b/Sources/Common/Extensions/StringExtension.swift index 09050cfe2..9282a43b4 100644 --- a/Sources/Common/Extensions/StringExtension.swift +++ b/Sources/Common/Extensions/StringExtension.swift @@ -394,9 +394,9 @@ public extension String { // MARK: Regex - func matches(_ regex: NSRegularExpression) -> Bool { - let matches = regex.matches(in: self, options: .anchored, range: self.fullRange) - return matches.count == 1 + func matches(_ regex: RegEx) -> Bool { + let firstMatch = firstMatch(of: regex, options: .anchored) + return firstMatch != nil } func matches(pattern: String, options: NSRegularExpression.Options = [.caseInsensitive]) -> Bool { @@ -406,7 +406,7 @@ public extension String { return matches(regex) } - func replacing(_ regex: NSRegularExpression, with replacement: String) -> String { + func replacing(_ regex: RegEx, with replacement: String) -> String { regex.stringByReplacingMatches(in: self, range: self.fullRange, withTemplate: replacement) } diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index f4c08e446..4ed457a95 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -20,23 +20,21 @@ import Common import Foundation import Networking -public protocol APIClientProtocol { - func load(_ requestConfig: Request) async throws -> Request.ResponseType -} - -public extension APIClientProtocol where Self == APIClient { - static var production: APIClientProtocol { APIClient(environment: .production) } - static var staging: APIClientProtocol { APIClient(environment: .staging) } +extension APIClient { + // used internally for testing + protocol Mockable { + func load(_ requestConfig: Request) async throws -> 
Request.Response + } } +extension APIClient: APIClient.Mockable {} public protocol APIClientEnvironment { - func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 - func url(for request: APIClient.Request) -> URL - func timeout(for request: APIClient.Request) -> TimeInterval + func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 + func url(for requestType: APIRequestType) -> URL } -public extension APIClient { - enum DefaultEnvironment: APIClientEnvironment { +public extension MaliciousSiteDetector { + enum APIEnvironment: APIClientEnvironment { case production case staging @@ -49,7 +47,7 @@ public extension APIClient { } var defaultHeaders: APIRequestV2.HeadersV2 { - .init(userAgent: APIRequest.Headers.userAgent) + .init(userAgent: Networking.APIRequest.Headers.userAgent) } enum APIPath { @@ -64,8 +62,8 @@ public extension APIClient { static let hashPrefix = "hashPrefix" } - public func url(for request: APIClient.Request) -> URL { - switch request { + public func url(for requestType: APIRequestType) -> URL { + switch requestType { case .hashPrefixSet(let configuration): endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ QueryParameter.category: configuration.threatKind.rawValue, @@ -81,45 +79,31 @@ public extension APIClient { } } - public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 { + public func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 { defaultHeaders } - - public func timeout(for request: APIClient.Request) -> TimeInterval { - switch request { - case .hashPrefixSet, .filterSet: 60 - // This could block navigation so we should favour navigation loading if the backend is degraded. - // On Android we're looking at a maximum 1 second timeout for this request. 
- case .matches: 1 - } - } } } -public struct APIClient: APIClientProtocol { +struct APIClient { let environment: APIClientEnvironment private let service: APIService - public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) { - self.init(environment: environment as APIClientEnvironment, service: service) - } - - public init(environment: APIClientEnvironment, service: APIService) { + init(environment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared)) { self.environment = environment self.service = service } - public func load(_ requestConfig: Request) async throws -> Request.ResponseType { + func load(_ requestConfig: R) async throws -> R.Response { let requestType = requestConfig.requestType let headers = environment.headers(for: requestType) let url = environment.url(for: requestType) - let timeout = environment.timeout(for: requestType) - let apiRequest = APIRequestV2(url: url, method: .get, headers: headers, timeoutInterval: timeout) + let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) let response = try await service.fetch(request: apiRequest) - let result: Request.ResponseType = try response.decodeBody() + let result: R.Response = try response.decodeBody() return result } @@ -127,18 +111,18 @@ public struct APIClient: APIClientProtocol { } // MARK: - Convenience -extension APIClientProtocol { - public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { +extension APIClient.Mockable { + func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { let result = try await load(.filterSet(threatKind: threatKind, revision: revision)) return result } - public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { + func hashPrefixesChangeSet(for 
threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision)) return result } - public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { + func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { let result = try await load(.matches(hashPrefix: hashPrefix)) return result } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift index af18fb2f2..1efc8c01b 100644 --- a/Sources/MaliciousSiteProtection/API/APIRequest.swift +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -18,66 +18,88 @@ import Foundation -public protocol APIRequestProtocol { - associatedtype ResponseType: Decodable - var requestType: APIClient.Request { get } +// Enumerated request type to delegate URLs forming to an API environment instance +public enum APIRequestType { + case hashPrefixSet(APIRequestType.HashPrefixes) + case filterSet(APIRequestType.FilterSet) + case matches(APIRequestType.Matches) } -public extension APIClient { - enum Request { - case hashPrefixSet(HashPrefixes) - case filterSet(FilterSet) - case matches(Matches) +extension APIClient { + // Protocol for defining typed requests with a specific response type. + protocol Request { + associatedtype Response: Decodable // Strongly-typed response type + var requestType: APIRequestType { get } // Enumerated type of request being made + } + + // Protocol for requests that modify a set of malicious site detection data + // (returning insertions/removals along with the updated revision) + protocol ChangeSetRequest: Request { + init(threatKind: ThreatKind, revision: Int?) 
} } -public extension APIClient.Request { - struct HashPrefixes: APIRequestProtocol { - public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet - public let threatKind: ThreatKind - public let revision: Int? +public extension APIRequestType { + struct HashPrefixes: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.HashPrefixesChangeSet - public var requestType: APIClient.Request { + let threatKind: ThreatKind + let revision: Int? + + init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } + + var requestType: APIRequestType { .hashPrefixSet(self) } } } -extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes { +/// extension to call generic `load(_: some Request)` method like this: `load(.hashPrefixes(…))` +extension APIClient.Request where Self == APIRequestType.HashPrefixes { static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { .init(threatKind: threatKind, revision: revision) } } -public extension APIClient.Request { - struct FilterSet: APIRequestProtocol { - public typealias ResponseType = APIClient.Response.FiltersChangeSet +public extension APIRequestType { + struct FilterSet: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.FiltersChangeSet + + let threatKind: ThreatKind + let revision: Int? - public let threatKind: ThreatKind - public let revision: Int? + init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } - public var requestType: APIClient.Request { + var requestType: APIRequestType { .filterSet(self) } } } -extension APIRequestProtocol where Self == APIClient.Request.FilterSet { +/// extension to call generic `load(_: some Request)` method like this: `load(.filterSet(…))` +extension APIClient.Request where Self == APIRequestType.FilterSet { static func filterSet(threatKind: ThreatKind, revision: Int?) 
-> Self { .init(threatKind: threatKind, revision: revision) } } -public extension APIClient.Request { - struct Matches: APIRequestProtocol { - public typealias ResponseType = APIClient.Response.Matches +public extension APIRequestType { + struct Matches: APIClient.Request { + typealias Response = APIClient.Response.Matches - public let hashPrefix: String + let hashPrefix: String - public var requestType: APIClient.Request { + var requestType: APIRequestType { .matches(self) } } } -extension APIRequestProtocol where Self == APIClient.Request.Matches { +/// extension to call generic `load(_: some Request)` method like this: `load(.matches(…))` +extension APIClient.Request where Self == APIRequestType.Matches { static func matches(hashPrefix: String) -> Self { .init(hashPrefix: hashPrefix) } diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift index 7988c2a32..eaf4f287c 100644 --- a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift +++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift @@ -32,6 +32,10 @@ extension APIClient { self.revision = revision self.replace = replace } + + public var isEmpty: Bool { + insert.isEmpty && delete.isEmpty + } } public enum Response { diff --git a/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift index 15b47d8e1..3e44f3bcd 100644 --- a/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift +++ b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift @@ -25,7 +25,6 @@ public extension os.Logger { public static var api = os.Logger(subsystem: "MSP", category: "API") public static var dataManager = os.Logger(subsystem: "MSP", category: "DataManager") public static var updateManager = os.Logger(subsystem: "MSP", category: "UpdateManager") - static var phishingDetectionTasks = os.Logger(subsystem: "MSP", category: 
"BackgroundActivities") } } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 592e7e852..d38fc4c12 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -21,93 +21,109 @@ import CryptoKit import Foundation public protocol MaliciousSiteDetecting { + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). + /// - Parameter url: The URL to evaluate. + /// - Returns: An optional `ThreatKind` indicating the type of threat, or `.none` if no threat is detected. func evaluate(_ url: URL) async -> ThreatKind? } +/// Class responsible for detecting malicious sites by evaluating URLs against local filters and an external API. +/// entry point: `func evaluate(_: URL) async -> ThreatKind?` public final class MaliciousSiteDetector: MaliciousSiteDetecting { - // for easier Xcode symbol navigation + // Type aliases for easier symbol navigation in Xcode. 
typealias PhishingDetector = MaliciousSiteDetector typealias MalwareDetector = MaliciousSiteDetector - let hashPrefixStoreLength: Int = 8 - let hashPrefixParamLength: Int = 4 - let apiClient: APIClientProtocol - let dataManager: DataManaging - let eventMapping: EventMapping + private enum Constants { + static let hashPrefixStoreLength: Int = 8 + static let hashPrefixParamLength: Int = 4 + } + + private let apiClient: APIClient.Mockable + private let dataManager: DataManaging + private let eventMapping: EventMapping - public init(apiClient: APIClientProtocol = APIClient(), dataManager: DataManaging, eventMapping: EventMapping) { + public convenience init(apiEnvironment: APIClientEnvironment, dataManager: DataManager, eventMapping: EventMapping) { + self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager, eventMapping: eventMapping) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging, eventMapping: EventMapping) { self.apiClient = apiClient self.dataManager = dataManager self.eventMapping = eventMapping } - private func inFilterSet(hash: String) -> Set { - return Set(dataManager.filterSet.filter { $0.hash == hash }) - } - - private func matchesUrl(hash: String, regexPattern: String, url: URL, hostnameHash: String) -> Bool { - if hash == hostnameHash, - let regex = try? NSRegularExpression(pattern: regexPattern, options: []) { - let urlString = url.absoluteString - let range = NSRange(location: 0, length: urlString.utf16.count) - return regex.firstMatch(in: urlString, options: [], range: range) != nil - } - return false - } + private func checkLocalFilters(hostHash: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool { + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: threatKind)) + let matchesLocalFilters = filterSet[hostHash]?.contains(where: { regex in + canonicalUrl.absoluteString.matches(pattern: regex) + }) ?? 
false - private func generateHashPrefix(for canonicalHost: String, length: Int) -> String { - let hostnameHash = SHA256.hash(data: Data(canonicalHost.utf8)).map { String(format: "%02hhx", $0) }.joined() - return String(hostnameHash.prefix(length)) + return matchesLocalFilters } - private func fetchMatches(hashPrefix: String) async -> [Match] { + private func checkApiMatches(hostHash: String, canonicalUrl: URL) async -> Match? { + let hashPrefixParam = String(hostHash.prefix(Constants.hashPrefixParamLength)) + let matches: [Match] do { - let response = try await apiClient.matches(forHashPrefix: hashPrefix) - return response.matches + matches = try await apiClient.matches(forHashPrefix: hashPrefixParam).matches } catch { - Logger.api.error("Failed to fetch matches for hash prefix: \(hashPrefix): \(error.localizedDescription)") - return [] + Logger.general.error("Error fetching matches from API: \(error)") + return nil } - } - private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool { - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - let filterHit = inFilterSet(hash: hostnameHash) - for filter in filterHit where matchesUrl(hash: filter.hash, regexPattern: filter.regex, url: canonicalUrl, hostnameHash: hostnameHash) { - eventMapping.fire(.errorPageShown(clientSideHit: true)) - return true + if let match = matches.first(where: { match in + match.hash == hostHash && canonicalUrl.absoluteString.matches(pattern: match.regex) + }) { + return match } - return false + return nil } - private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Bool { - let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: hashPrefixParamLength) - let matches = await fetchMatches(hashPrefix: hashPrefixParam) - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - for match in matches where matchesUrl(hash: match.hash, regexPattern: match.regex, url: canonicalUrl, hostnameHash: 
hostnameHash) { - eventMapping.fire(.errorPageShown(clientSideHit: false)) - return true + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). + public func evaluate(_ url: URL) async -> ThreatKind? { + guard let canonicalHost = url.canonicalHost(), + let canonicalUrl = url.canonicalURL() else { return .none } + + let hostHash = canonicalHost.sha256 + let hashPrefix = String(hostHash.prefix(Constants.hashPrefixStoreLength)) + + // 1. Check for matching hash prefixes. + // The hash prefix list serves as a representation of the entire database: + // every malicious website will have a hash prefix that it collides with. + var hashPrefixMatchingThreatKinds = [ThreatKind]() + for threatKind in ThreatKind.allCases { // e.g., phishing, malware, etc. + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind)) + if hashPrefixes.contains(hashPrefix) { + hashPrefixMatchingThreatKinds.append(threatKind) + } } - return false - } - public func evaluate(_ url: URL) async -> ThreatKind? { - guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return .none } - - for threatKind in ThreatKind.allCases { - let hashPrefix = generateHashPrefix(for: canonicalHost, length: hashPrefixStoreLength) - if dataManager.hashPrefixes.contains(hashPrefix) { - // Check local filterSet first - if checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return threatKind - } - // If nothing found, hit the API to get matches - if await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return threatKind - } + // Return no threats if no matching hash prefixes are found in the database. + guard !hashPrefixMatchingThreatKinds.isEmpty else { return .none } + + // 2. Check local Filter Sets. 
+ // The filter set acts as a local cache of some database entries, containing + // the 5000 most common threats (or those most likely to collide with daily + // browsing behaviors, based on Clickhouse's top 10k, ranked by Netcraft's risk rating). + for threatKind in hashPrefixMatchingThreatKinds { + let matches = await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) + if matches { + eventMapping.fire(.errorPageShown(clientSideHit: true, threatKind: threatKind)) + return threatKind } } + // 3. If no locally cached filters matched, we will still make a request to the API + // to check for potential matches on our backend. + let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) + if let match { + let threatKind = match.category.flatMap(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0] + eventMapping.fire(.errorPageShown(clientSideHit: false, threatKind: threatKind)) + return threatKind + } + return .none } + } diff --git a/Sources/MaliciousSiteProtection/Model/Event.swift b/Sources/MaliciousSiteProtection/Model/Event.swift index 31eab462a..8903f4d70 100644 --- a/Sources/MaliciousSiteProtection/Model/Event.swift +++ b/Sources/MaliciousSiteProtection/Model/Event.swift @@ -27,7 +27,7 @@ public extension PixelKit { } public enum Event: PixelKitEventV2 { - case errorPageShown(clientSideHit: Bool) + case errorPageShown(clientSideHit: Bool, threatKind: ThreatKind) case visitSite case iframeLoaded case updateTaskFailed48h(error: Error?) @@ -50,7 +50,7 @@ public enum Event: PixelKitEventV2 { public var parameters: [String: String]? 
{ switch self { - case .errorPageShown(let clientSideHit): + case .errorPageShown(let clientSideHit, threatKind: _): return [PixelKit.Parameters.clientSideHit: String(clientSideHit)] case .visitSite: return [:] diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift new file mode 100644 index 000000000..b67cd82ef --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -0,0 +1,78 @@ +// +// FilterDictionary.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +struct FilterDictionary: Codable, Equatable { + + /// Filter set revision + var revision: Int + + /// [Hash: [RegEx]] mapping + /// + /// - **Key**: SHA256 hash sum of a canonical host name + /// - **Value**: An array of regex patterns used to match whole URLs + /// + /// ``` + /// { + /// "3aeb002460381c6f258e8395d3026f571f0d9a76488dcd837639b13aed316560" : [ + /// "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?[\\/\\\\]+BETS1O\\-GIRIS[\\/\\\\]+BETS1O(?:[\\/\\\\]+|\\?|$)" + /// ], + /// ... + /// } + /// ``` + var filters: [String: Set] + + /// Subscript to access regex patterns by SHA256 host name hash + subscript(hash: String) -> Set? 
{ + filters[hash] + } + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { + for filter in itemsToDelete { + // Remove the filter from the Set stored in the Dictionary by hash used as a key. + // If the Set becomes empty – remove the Set value from the Dictionary. + // + // The following code is equivalent to this one but without the Set value being copied + // or key being searched multiple times: + /* + if var filterSet = self.filters[filter.hash] { + filterSet.remove(filter.regex) + if filterSet.isEmpty { + self.filters[filter.hash] = nil + } else { + self.filters[filter.hash] = filterSet + } + } + */ + withUnsafeMutablePointer(to: &filters[filter.hash]) { item in + item.pointee?.remove(filter.regex) + if item.pointee?.isEmpty == true { + item.pointee = nil + } + } + } + } + + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter { + for filter in itemsToAdd { + filters[filter.hash, default: []].insert(filter.regex) + } + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift new file mode 100644 index 000000000..7aec5244d --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift @@ -0,0 +1,45 @@ +// +// HashPrefixSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +import Foundation + +/// Structure storing a Set of hash prefixes ["6fe1e7c8","1d760415",...] and a revision of the set. +struct HashPrefixSet: Codable, Equatable { + + var revision: Int + var set: Set + + init(revision: Int, items: some Sequence) { + self.revision = revision + self.set = Set(items) + } + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String { + set.subtract(itemsToDelete) + } + + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String { + set.formUnion(itemsToAdd) + } + + @inline(__always) + func contains(_ item: String) -> Bool { + set.contains(item) + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift new file mode 100644 index 000000000..8a23785ae --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift @@ -0,0 +1,71 @@ +// +// IncrementallyUpdatableDataSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +protocol IncrementallyUpdatableDataSet: Codable, Equatable { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element: Codable, Hashable + /// API Request type used to fetch updates for the data set + associatedtype APIRequest: APIClient.ChangeSetRequest where APIRequest.Response == APIClient.ChangeSetResponse + + var revision: Int { get set } + + init(revision: Int, items: some Sequence) + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Element + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Element + + /// Apply ChangeSet from local data revision to actual revision loaded from API + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) +} + +extension IncrementallyUpdatableDataSet { + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) { + if changeSet.replace { + self = .init(revision: changeSet.revision, items: changeSet.insert) + } else { + self.subtract(changeSet.delete) + self.formUnion(changeSet.insert) + self.revision = changeSet.revision + } + } +} + +extension HashPrefixSet: IncrementallyUpdatableDataSet { + typealias Element = String + typealias APIRequest = APIRequestType.HashPrefixes + + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + .hashPrefixes(threatKind: threatKind, revision: revision) + } +} + +extension FilterDictionary: IncrementallyUpdatableDataSet { + typealias Element = Filter + typealias APIRequest = APIRequestType.FilterSet + + init(revision: Int, items: some Sequence) { + let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in + result[filter.hash, default: []].insert(filter.regex) + } + self.init(revision: revision, filters: filtersDictionary) + } + + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + .filterSet(threatKind: threatKind, revision: revision) + } +} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift 
b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift similarity index 51% rename from Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift rename to Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift index 6f2f8a20a..be67cb6fc 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift +++ b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift @@ -1,5 +1,5 @@ // -// BackgroundActivitySchedulerMock.swift +// LoadableFromEmbeddedData.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -16,19 +16,19 @@ // limitations under the License. // -import Foundation -import MaliciousSiteProtection -actor MockBackgroundActivityScheduler: BackgroundActivityScheduling { - var startCalled = false - var stopCalled = false - var interval: TimeInterval = 1 - var identifier: String = "test" +public protocol LoadableFromEmbeddedData { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element + /// Decoded data type stored in the embedded json file + associatedtype EmbeddedDataSet: Decodable, Sequence where EmbeddedDataSet.Element == Self.Element - func start() { - startCalled = true - } + init(revision: Int, items: some Sequence) +} + +extension HashPrefixSet: LoadableFromEmbeddedData { + public typealias EmbeddedDataSet = [String] +} - func stop() { - stopCalled = true - } +extension FilterDictionary: LoadableFromEmbeddedData { + public typealias EmbeddedDataSet = [Filter] } diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift new file mode 100644 index 000000000..a064be076 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -0,0 +1,104 @@ +// +// StoredData.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +protocol MaliciousSiteDataKey: Hashable { + associatedtype EmbeddedDataSet: Decodable + associatedtype DataSet: IncrementallyUpdatableDataSet, LoadableFromEmbeddedData + + var dataType: DataManager.StoredDataType { get } + var threatKind: ThreatKind { get } +} + +public extension DataManager { + enum StoredDataType: Hashable, CaseIterable { + case hashPrefixSet(HashPrefixes) + case filterSet(FilterSet) + + enum Kind: CaseIterable { + case hashPrefixSet, filterSet + } + // keep to get a compiler error when number of cases changes + var kind: Kind { + switch self { + case .hashPrefixSet: .hashPrefixSet + case .filterSet: .filterSet + } + } + + var dataKey: any MaliciousSiteDataKey { + switch self { + case .hashPrefixSet(let key): key + case .filterSet(let key): key + } + } + + public var threatKind: ThreatKind { + switch self { + case .hashPrefixSet(let key): key.threatKind + case .filterSet(let key): key.threatKind + } + } + + public static var allCases: [DataManager.StoredDataType] { + ThreatKind.allCases.map { threatKind in + Kind.allCases.map { dataKind in + switch dataKind { + case .hashPrefixSet: .hashPrefixSet(.init(threatKind: threatKind)) + case .filterSet: .filterSet(.init(threatKind: threatKind)) + } + } + }.flatMap { $0 } + } + } +} + +public extension DataManager.StoredDataType { + struct HashPrefixes: MaliciousSiteDataKey { + typealias DataSet = HashPrefixSet + + let 
threatKind: ThreatKind + + var dataType: DataManager.StoredDataType { + .hashPrefixSet(self) + } + } +} +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} + +public extension DataManager.StoredDataType { + struct FilterSet: MaliciousSiteDataKey { + typealias DataSet = FilterDictionary + + let threatKind: ThreatKind + + var dataType: DataManager.StoredDataType { + .filterSet(self) + } + } +} +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.FilterSet { + static func filterSet(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} diff --git a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift deleted file mode 100644 index bcf06d518..000000000 --- a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift +++ /dev/null @@ -1,106 +0,0 @@ -// -// PhishingDetectionDataActivities.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -public protocol BackgroundActivityScheduling: Actor { - func start() - func stop() -} -actor BackgroundActivityScheduler: BackgroundActivityScheduling { - - private var task: Task? - private var timer: Timer? 
- private let interval: TimeInterval - private let identifier: String - private let activity: () async -> Void - - init(interval: TimeInterval, identifier: String, activity: @escaping () async -> Void) { - self.interval = interval - self.identifier = identifier - self.activity = activity - } - - func start() { - stop() - task = Task { - let taskId = UUID().uuidString - while !Task.isCancelled { - await activity() - do { - Logger.phishingDetectionTasks.debug("🟢 \(self.identifier) task was executed in instance \(taskId)") - try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000)) - } catch { - Logger.phishingDetectionTasks.error("🔴 Error \(self.identifier) task was cancelled before it could finish sleeping.") - break - } - } - } - } - - func stop() { - task?.cancel() - task = nil - } -} - -public protocol PhishingDetectionDataActivityHandling { - func start() - func stop() -} - -public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandling { - private var schedulers: [BackgroundActivityScheduler] - private var running: Bool = false - - public init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManaging) { - let hashPrefixScheduler = BackgroundActivityScheduler( - interval: hashPrefixInterval, - identifier: "hashPrefixes.update", - activity: { await updateManager.updateHashPrefixes() } - ) - let filterSetScheduler = BackgroundActivityScheduler( - interval: filterSetInterval, - identifier: "filterSet.update", - activity: { await updateManager.updateFilterSet() } - ) - self.schedulers = [hashPrefixScheduler, filterSetScheduler] - } - - public func start() { - if !running { - Task { - for scheduler in schedulers { - await scheduler.start() - } - } - running = true - } - } - - public func stop() { - Task { - for scheduler in schedulers { - await scheduler.stop() - } - } - running = false - } -} diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift 
b/Sources/MaliciousSiteProtection/Services/DataManager.swift index 41b2dba9d..8e4426dd1 100644 --- a/Sources/MaliciousSiteProtection/Services/DataManager.swift +++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift @@ -17,164 +17,89 @@ // import Foundation -import Common import os -public protocol DataManaging { - var filterSet: Set { get } - var hashPrefixes: Set { get } - var currentRevision: Int { get } - func saveFilterSet(set: Set) - func saveHashPrefixes(set: Set) - func saveRevision(_ revision: Int) +protocol DataManaging { + func dataSet(for key: DataKey) async -> DataKey.DataSet + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async } -public final class DataManager: DataManaging { - private lazy var _filterSet: Set = { - loadFilterSet() - }() - - private lazy var _hashPrefixes: Set = { - loadHashPrefix() - }() - - private lazy var _currentRevision: Int = { - loadRevision() - }() - - public private(set) var filterSet: Set { - get { _filterSet } - set { _filterSet = newValue } - } - public private(set) var hashPrefixes: Set { - get { _hashPrefixes } - set { _hashPrefixes = newValue } - } - public private(set) var currentRevision: Int { - get { _currentRevision } - set { _currentRevision = newValue } - } +public actor DataManager: DataManaging { private let embeddedDataProvider: EmbeddedDataProviding private let fileStore: FileStoring - private let encoder = JSONEncoder() - private let revisionFilename = "revision.txt" - private let hashPrefixFilename = "phishingHashPrefixes.json" - private let filterSetFilename = "phishingFilterSet.json" - public init(embeddedDataProvider: EmbeddedDataProviding, fileStore: FileStoring? 
= nil) { + public typealias FileNameProvider = (DataManager.StoredDataType) -> String + private nonisolated let fileNameProvider: FileNameProvider + + private var store: [StoredDataType: Any] = [:] + + public init(fileStore: FileStoring, embeddedDataProvider: EmbeddedDataProviding, fileNameProvider: @escaping FileNameProvider) { self.embeddedDataProvider = embeddedDataProvider - self.fileStore = fileStore ?? FileStore() + self.fileStore = fileStore + self.fileNameProvider = fileNameProvider } - private func writeHashPrefixes() { - let encoder = JSONEncoder() - do { - let hashPrefixesData = try encoder.encode(Array(hashPrefixes)) - fileStore.write(data: hashPrefixesData, to: hashPrefixFilename) - } catch { - Logger.dataManager.error("Error saving hash prefixes data: \(error.localizedDescription)") + func dataSet(for key: DataKey) -> DataKey.DataSet { + let dataType = key.dataType + // return cached dataSet if available + if let data = store[key.dataType] as? DataKey.DataSet { + return data } - } - private func writeFilterSet() { - let encoder = JSONEncoder() - do { - let filterSetData = try encoder.encode(Array(filterSet)) - fileStore.write(data: filterSetData, to: filterSetFilename) - } catch { - Logger.dataManager.error("Error saving filter set data: \(error.localizedDescription)") - } - } + // read stored dataSet if it‘s newer than the embedded one + let dataSet = readStoredDataSet(for: key) ?? 
{ + // no stored dataSet or the embedded one is newer + let embeddedRevision = embeddedDataProvider.revision(for: dataType) + let embeddedItems = embeddedDataProvider.loadDataSet(for: key) + return .init(revision: embeddedRevision, items: embeddedItems) + }() - private func writeRevision() { - let encoder = JSONEncoder() - do { - let revisionData = try encoder.encode(currentRevision) - fileStore.write(data: revisionData, to: revisionFilename) - } catch { - Logger.dataManager.error("Error saving revision data: \(error.localizedDescription)") - } - } + // cache + store[dataType] = dataSet - private func loadHashPrefix() -> Set { - guard let data = fileStore.read(from: hashPrefixFilename) else { - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } - let decoder = JSONDecoder() - do { - if loadRevisionFromDisk() < embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } - let onDiskHashPrefixes = Set(try decoder.decode(Set.self, from: data)) - return onDiskHashPrefixes - } catch { - Logger.dataManager.error("Error decoding \(self.hashPrefixFilename): \(error.localizedDescription)") - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } + return dataSet } - private func loadFilterSet() -> Set { - guard let data = fileStore.read(from: filterSetFilename) else { - return embeddedDataProvider.loadEmbeddedFilterSet() - } - let decoder = JSONDecoder() + private func readStoredDataSet(for key: DataKey) -> DataKey.DataSet? 
{ + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + guard let data = fileStore.read(from: fileName) else { return nil } + + let storedDataSet: DataKey.DataSet do { - if loadRevisionFromDisk() < embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.loadEmbeddedFilterSet() - } - let onDiskFilterSet = Set(try decoder.decode(Set.self, from: data)) - return onDiskFilterSet + storedDataSet = try JSONDecoder().decode(DataKey.DataSet.self, from: data) } catch { - Logger.dataManager.error("Error decoding \(self.filterSetFilename): \(error.localizedDescription)") - return embeddedDataProvider.loadEmbeddedFilterSet() + Logger.dataManager.error("Error decoding \(fileName): \(error.localizedDescription)") + return nil } - } - private func loadRevisionFromDisk() -> Int { - guard let data = fileStore.read(from: revisionFilename) else { - return embeddedDataProvider.embeddedRevision - } - let decoder = JSONDecoder() - do { - return try decoder.decode(Int.self, from: data) - } catch { - Logger.dataManager.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return embeddedDataProvider.embeddedRevision + // compare to the embedded data revision + let embeddedDataRevision = embeddedDataProvider.revision(for: dataType) + guard storedDataSet.revision >= embeddedDataRevision else { + Logger.dataManager.error("Stored \(fileName) is outdated: revision: \(storedDataSet.revision), embedded revision: \(embeddedDataRevision).") + return nil } + + return storedDataSet } - private func loadRevision() -> Int { - guard let data = fileStore.read(from: revisionFilename) else { - return embeddedDataProvider.embeddedRevision - } - let decoder = JSONDecoder() + func store(_ dataSet: DataKey.DataSet, for key: DataKey) { + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + self.store[dataType] = dataSet + + let data: Data do { - let loadedRevision = try decoder.decode(Int.self, from: data) - if loadedRevision < 
embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.embeddedRevision - } - return loadedRevision + data = try JSONEncoder().encode(dataSet) } catch { - Logger.dataManager.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return embeddedDataProvider.embeddedRevision + Logger.dataManager.error("Error encoding \(fileName): \(error.localizedDescription)") + assertionFailure("Failed to store data to \(fileName): \(error)") + return } - } -} - -extension DataManager { - public func saveFilterSet(set: Set) { - self.filterSet = set - writeFilterSet() - } - public func saveHashPrefixes(set: Set) { - self.hashPrefixes = set - writeHashPrefixes() + let success = fileStore.write(data: data, to: fileName) + assert(success) } - public func saveRevision(_ revision: Int) { - self.currentRevision = revision - writeRevision() - } } diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift index 5ca4d9f7c..c9c82c2a0 100644 --- a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -16,58 +16,41 @@ // limitations under the License. 
// -import Foundation import CryptoKit -import Common -import os +import Foundation public protocol EmbeddedDataProviding { - var embeddedRevision: Int { get } - func loadEmbeddedFilterSet() -> Set - func loadEmbeddedHashPrefixes() -> Set -} - -public struct EmbeddedDataProvider: EmbeddedDataProviding { - public let embeddedRevision: Int - private let embeddedFilterSetURL: URL - private let embeddedFilterSetDataSHA: String - private let embeddedHashPrefixURL: URL - private let embeddedHashPrefixDataSHA: String + func revision(for dataType: DataManager.StoredDataType) -> Int + func url(for dataType: DataManager.StoredDataType) -> URL + func hash(for dataType: DataManager.StoredDataType) -> String - public init(revision: Int, filterSetURL: URL, filterSetDataSHA: String, hashPrefixURL: URL, hashPrefixDataSHA: String) { - embeddedFilterSetURL = filterSetURL - embeddedFilterSetDataSHA = filterSetDataSHA - embeddedHashPrefixURL = hashPrefixURL - embeddedHashPrefixDataSHA = hashPrefixDataSHA - embeddedRevision = revision - } + func data(withContentsOf url: URL) throws -> Data +} - private func loadData(from url: URL, expectedSHA: String) throws -> Data { - let data = try Data(contentsOf: url) - let sha256 = SHA256.hash(data: data) - let hashString = sha256.compactMap { String(format: "%02x", $0) }.joined() +extension EmbeddedDataProviding { - guard hashString == expectedSHA else { - throw NSError(domain: "PhishingDetectionDataProvider", code: 1001, userInfo: [NSLocalizedDescriptionKey: "SHA mismatch"]) + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet { + let dataType = key.dataType + let url = url(for: dataType) + let data: Data + do { + data = try self.data(withContentsOf: url) +#if DEBUG + assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)") +#endif + } catch { + fatalError("\(self): Could not load embedded data set at “\(url)”: \(error)") } - return data - } - - public func loadEmbeddedFilterSet() -> Set { - do { - let filterSetData 
= try loadData(from: embeddedFilterSetURL, expectedSHA: embeddedFilterSetDataSHA) - return try JSONDecoder().decode(Set.self, from: filterSetData) - } catch { - fatalError("🔴 Error: SHA mismatch for filterSet JSON file. Expected \(self.embeddedFilterSetDataSHA)") - } - } - - public func loadEmbeddedHashPrefixes() -> Set { do { - let hashPrefixData = try loadData(from: embeddedHashPrefixURL, expectedSHA: embeddedHashPrefixDataSHA) - return try JSONDecoder().decode(Set.self, from: hashPrefixData) + let result = try JSONDecoder().decode(DataKey.EmbeddedDataSet.self, from: data) + return result } catch { - fatalError("🔴 Error: SHA mismatch for hashPrefixes JSON file. Expected \(self.embeddedHashPrefixDataSHA)") + fatalError("\(self): Could not decode embedded data set at “\(url)”: \(error)") } } + + public func data(withContentsOf url: URL) throws -> Data { + try Data(contentsOf: url) + } + } diff --git a/Sources/MaliciousSiteProtection/Services/FileStore.swift b/Sources/MaliciousSiteProtection/Services/FileStore.swift index e0714401a..06418e6a2 100644 --- a/Sources/MaliciousSiteProtection/Services/FileStore.swift +++ b/Sources/MaliciousSiteProtection/Services/FileStore.swift @@ -20,22 +20,15 @@ import Foundation import os public protocol FileStoring { - func write(data: Data, to filename: String) + @discardableResult func write(data: Data, to filename: String) -> Bool func read(from filename: String) -> Data? 
} -public struct FileStore: FileStoring { +public struct FileStore: FileStoring, CustomDebugStringConvertible { private let dataStoreURL: URL - public init() { - let dataStoreDirectory: URL - do { - dataStoreDirectory = try FileManager.default.url(for: .applicationSupportDirectory, in: .userDomainMask, appropriateFor: nil, create: true) - } catch { - Logger.dataManager.error("Error accessing application support directory: \(error.localizedDescription)") - dataStoreDirectory = FileManager.default.temporaryDirectory - } - dataStoreURL = dataStoreDirectory.appendingPathComponent(Bundle.main.bundleIdentifier!, isDirectory: true) + public init(dataStoreURL: URL) { + self.dataStoreURL = dataStoreURL createDirectoryIfNeeded() } @@ -47,12 +40,14 @@ public struct FileStore: FileStoring { } } - public func write(data: Data, to filename: String) { + public func write(data: Data, to filename: String) -> Bool { let fileURL = dataStoreURL.appendingPathComponent(filename) do { try data.write(to: fileURL) + return true } catch { Logger.dataManager.error("Error writing to directory: \(error.localizedDescription)") + return false } } @@ -65,4 +60,8 @@ public struct FileStore: FileStoring { return nil } } + + public var debugDescription: String { + return "<\(type(of: self)) - \"\(dataStoreURL.path)\">" + } } diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index af9f60e7d..2b91ac051 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -20,76 +20,78 @@ import Foundation import Common import os -public protocol UpdateManaging { - func updateFilterSet() async - func updateHashPrefixes() async +protocol UpdateManaging { + func updateData(for key: some MaliciousSiteDataKey) async + + func startPeriodicUpdates() -> Task } public struct UpdateManager: UpdateManaging { - private let apiClient: APIClientProtocol + + 
private let apiClient: APIClient.Mockable private let dataManager: DataManaging - public init(apiClient: APIClientProtocol, dataManager: DataManaging) { + public typealias UpdateIntervalProvider = (DataManager.StoredDataType) -> TimeInterval? + private let updateIntervalProvider: UpdateIntervalProvider + private let sleeper: Sleeper + + public init(apiEnvironment: APIClientEnvironment, dataManager: DataManager, updateIntervalProvider: @escaping UpdateIntervalProvider) { + self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager, updateIntervalProvider: updateIntervalProvider) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging, sleeper: Sleeper = .default, updateIntervalProvider: @escaping UpdateIntervalProvider) { self.apiClient = apiClient self.dataManager = dataManager + self.updateIntervalProvider = updateIntervalProvider + self.sleeper = sleeper } - private func updateSet( - currentSet: Set, - insert: [T], - delete: [T], - replace: Bool, - saveSet: (Set) -> Void - ) { - var newSet = currentSet - - if replace { - newSet = Set(insert) - } else { - newSet.formUnion(insert) - newSet.subtract(delete) - } - - saveSet(newSet) - } + func updateData(for key: DataKey) async { + // load currently stored data set + var dataSet = await dataManager.dataSet(for: key) + let oldRevision = dataSet.revision - public func updateFilterSet() async { - let changeSet: APIClient.Response.FiltersChangeSet + // get change set from current revision from API + let changeSet: APIClient.ChangeSetResponse do { - changeSet = try await apiClient.filtersChangeSet(for: .phishing, revision: dataManager.currentRevision) + let request = DataKey.DataSet.APIRequest(threatKind: key.threatKind, revision: oldRevision) + changeSet = try await apiClient.load(request) } catch { Logger.updateManager.error("error fetching filter set: \(error)") return } - updateSet( - currentSet: dataManager.filterSet, - insert: changeSet.insert, - delete: changeSet.delete, - 
replace: changeSet.replace - ) { newSet in - self.dataManager.saveFilterSet(set: newSet) + guard !changeSet.isEmpty || changeSet.revision != dataSet.revision else { + Logger.updateManager.debug("no changes to filter set") + return } - dataManager.saveRevision(changeSet.revision) - Logger.updateManager.debug("filterSet updated to revision \(self.dataManager.currentRevision)") + + // apply changes + dataSet.apply(changeSet) + + // store back + await self.dataManager.store(dataSet, for: key) + Logger.updateManager.debug("\(type(of: key)).\(key.threatKind) updated from rev.\(oldRevision) to rev.\(dataSet.revision)") } - public func updateHashPrefixes() async { - let changeSet: APIClient.Response.HashPrefixesChangeSet - do { - changeSet = try await apiClient.hashPrefixesChangeSet(for: .phishing, revision: dataManager.currentRevision) - } catch { - Logger.updateManager.error("error fetching hash prefixes: \(error)") - return - } - updateSet( - currentSet: dataManager.hashPrefixes, - insert: changeSet.insert, - delete: changeSet.delete, - replace: changeSet.replace - ) { newSet in - self.dataManager.saveHashPrefixes(set: newSet) + public func startPeriodicUpdates() -> Task { + Task.detached { + // run update jobs in background for every data type + try await withThrowingTaskGroup(of: Never.self) { group in + for dataType in DataManager.StoredDataType.allCases { + // get update interval from provider + guard let updateInterval = updateIntervalProvider(dataType) else { continue } + assert(updateInterval > 0) + + group.addTask { + // run periodically until the parent task is cancelled + try await performPeriodicJob(interval: updateInterval, sleeper: sleeper) { + await self.updateData(for: dataType.dataKey) + } + } + } + for try await _ in group {} + } } - dataManager.saveRevision(changeSet.revision) - Logger.updateManager.debug("hashPrefixes updated to revision \(self.dataManager.currentRevision)") } + } diff --git a/Tests/CommonTests/Extensions/StringExtensionTests.swift 
b/Tests/CommonTests/Extensions/StringExtensionTests.swift index bcf415895..65abb79c8 100644 --- a/Tests/CommonTests/Extensions/StringExtensionTests.swift +++ b/Tests/CommonTests/Extensions/StringExtensionTests.swift @@ -16,8 +16,10 @@ // limitations under the License. // +import CryptoKit import Foundation import XCTest + @testable import Common final class StringExtensionTests: XCTestCase { @@ -370,4 +372,13 @@ final class StringExtensionTests: XCTestCase { } } + func testSha256() { + let string = "Hello, World! This is a test string." + let hash = string.sha256 + let expected = "3c2b805ab0038afb0629e1d598ae73e0caabb69de03e96762977d34e8ba428bf" + let expectedSHA256 = SHA256.hash(data: Data(string.utf8)).map { String(format: "%02hhx", $0) }.joined() + XCTAssertEqual(hash, expected) + XCTAssertEqual(hash, expectedSHA256) + } + } diff --git a/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift b/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift deleted file mode 100644 index 0640fc16f..000000000 --- a/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift +++ /dev/null @@ -1,57 +0,0 @@ -// -// BackgroundActivitySchedulerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -import Foundation -import XCTest -@testable import MaliciousSiteProtection - -class BackgroundActivitySchedulerTests: XCTestCase { - var scheduler: BackgroundActivityScheduler! - var activityWasRun = false - - override func tearDown() { - scheduler = nil - super.tearDown() - } - - func testStart() async throws { - let expectation = self.expectation(description: "Activity should run") - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - if !self.activityWasRun { - self.activityWasRun = true - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 2) - XCTAssertTrue(activityWasRun) - } - - func testRepeats() async throws { - let expectation = self.expectation(description: "Activity should repeat") - var runCount = 0 - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - runCount += 1 - if runCount == 2 { - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 3) - XCTAssertEqual(runCount, 2) - } -} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift index dd633be54..f6b0de23a 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift @@ -27,26 +27,24 @@ class MaliciousSiteDetectorTests: XCTestCase { private var mockEventMapping: MockEventMapping! private var detector: MaliciousSiteDetector! 
- override func setUp() { - super.setUp() + override func setUp() async throws { mockAPIClient = MockMaliciousSiteProtectionAPIClient() mockDataManager = MockMaliciousSiteProtectionDataManager() mockEventMapping = MockEventMapping() detector = MaliciousSiteDetector(apiClient: mockAPIClient, dataManager: mockDataManager, eventMapping: mockEventMapping) } - override func tearDown() { + override func tearDown() async throws { mockAPIClient = nil mockDataManager = nil mockEventMapping = nil detector = nil - super.tearDown() } func testIsMaliciousWithLocalFilterHit() async { let filter = Filter(hash: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*") - mockDataManager.filterSet = Set([filter]) - mockDataManager.hashPrefixes = Set(["255a8a79"]) + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["255a8a79"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://malicious.com/")! @@ -56,8 +54,8 @@ class MaliciousSiteDetectorTests: XCTestCase { } func testIsMaliciousWithApiMatch() async { - mockDataManager.filterSet = Set() - mockDataManager.hashPrefixes = ["a379a6f6"] + await mockDataManager.store(FilterDictionary(revision: 0, items: []), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["a379a6f6"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://example.com/mal")! 
@@ -68,8 +66,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithHashPrefixMatch() async { let filter = Filter(hash: "notamatch", regex: ".*malicious.*") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["4c64eb24"] // matches safe.com + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24" /* matches safe.com */]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! @@ -81,8 +79,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithFullHashMatch() async { // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b let filter = Filter(hash: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["4c64eb24"] + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! @@ -93,8 +91,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithNoHashPrefixMatch() async { let filter = Filter(hash: "testHash", regex: ".*malicious.*") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["testPrefix"] + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["testPrefix"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! 
diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index d32d264ea..fcea80939 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -30,7 +30,7 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { override func setUp() { super.setUp() mockService = MockAPIService() - client = .init(environment: .staging, service: mockService) + client = .init(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService) } override func tearDown() { diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift index a6763c2f3..5164f78d3 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift @@ -27,21 +27,28 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { static let hashPrefixesFileName = "phishingHashPrefixes.json" static let filterSetFileName = "phishingFilterSet.json" } - let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName, "revision.txt"] + let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName] var dataManager: MaliciousSiteProtection.DataManager! var fileStore: MaliciousSiteProtection.FileStoring! 
- override func setUp() { - super.setUp() + override func setUp() async throws { embeddedDataProvider = MockMaliciousSiteProtectionEmbeddedDataProvider() fileStore = MockMaliciousSiteProtectionFileStore() - dataManager = MaliciousSiteProtection.DataManager(embeddedDataProvider: embeddedDataProvider, fileStore: fileStore) + setUpDataManager() } - override func tearDown() { + func setUpDataManager() { + dataManager = MaliciousSiteProtection.DataManager(fileStore: fileStore, embeddedDataProvider: embeddedDataProvider, fileNameProvider: { dataType in + switch dataType { + case .filterSet: Constants.filterSetFileName + case .hashPrefixSet: Constants.hashPrefixesFileName + } + }) + } + + override func tearDown() async throws { embeddedDataProvider = nil dataManager = nil - super.tearDown() } func clearDatasets() { @@ -53,22 +60,22 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { func testWhenNoDataSavedThenProviderDataReturned() async { clearDatasets() - let expectedFilerSet = Set([Filter(hash: "some", regex: "some")]) + let expectedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let expectedFilterDict = FilterDictionary(revision: 65, items: expectedFilterSet) let expectedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: expectedFilerSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: expectedHashPrefix) + embeddedDataProvider.filterSet = expectedFilterSet + embeddedDataProvider.hashPrefixes = expectedHashPrefix - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(actualFilterSet, expectedFilerSet) - XCTAssertEqual(actualHashPrefix, expectedHashPrefix) + XCTAssertEqual(actualFilterSet, expectedFilterDict) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefix) } func 
testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async { let encoder = JSONEncoder() // On Disk Data Setup - fileStore.write(data: "1".utf8data, to: "revision.txt") let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")]) let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) let onDiskHashPrefix = Set(["faffa"]) @@ -79,27 +86,28 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { // Embedded Data Setup embeddedDataProvider.embeddedRevision = 5 let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 5, items: embeddedFilterSet) let embeddedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) - - let actualRevision = dataManager.currentRevision - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes - - XCTAssertEqual(actualFilterSet, embeddedFilterSet) - XCTAssertEqual(actualHashPrefix, embeddedHashPrefix) - XCTAssertEqual(actualRevision, 5) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 5) + XCTAssertEqual(actualHashPrefixRevision, 5) } func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async { - let encoder = JSONEncoder() // On Disk Data Setup - fileStore.write(data: "6".utf8data, to: "revision.txt") - let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")]) - let 
filterSetData = try! encoder.encode(Array(onDiskFilterSet)) - let onDiskHashPrefix = Set(["faffa"]) - let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix)) + let onDiskFilterDict = FilterDictionary(revision: 6, items: [Filter(hash: "other", regex: "other")]) + let filterSetData = try! JSONEncoder().encode(onDiskFilterDict) + let onDiskHashPrefix = HashPrefixSet(revision: 6, items: ["faffa"]) + let hashPrefixData = try! JSONEncoder().encode(onDiskHashPrefix) fileStore.write(data: filterSetData, to: Constants.filterSetFileName) fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName) @@ -107,16 +115,42 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { embeddedDataProvider.embeddedRevision = 1 let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) let embeddedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix - let actualRevision = dataManager.currentRevision - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision - XCTAssertEqual(actualFilterSet, onDiskFilterSet) + XCTAssertEqual(actualFilterSet, onDiskFilterDict) XCTAssertEqual(actualHashPrefix, onDiskHashPrefix) - XCTAssertEqual(actualRevision, 6) + XCTAssertEqual(actualFilterSetRevision, 6) + XCTAssertEqual(actualHashPrefixRevision, 6) + } + + func testWhenStoredDataIsMalformed_ThenEmbeddedDataIsLoaded() async { + // On Disk Data Setup + fileStore.write(data: "fake".utf8data, to: 
Constants.filterSetFileName) + fileStore.write(data: "fake".utf8data, to: Constants.hashPrefixesFileName) + + // Embedded Data Setup + embeddedDataProvider.embeddedRevision = 1 + let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 1, items: embeddedFilterSet) + let embeddedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) } func testWriteAndLoadData() async { @@ -125,31 +159,31 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { let expectedFilterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) let expectedRevision = 65 - dataManager.saveHashPrefixes(set: expectedHashPrefixes) - dataManager.saveFilterSet(set: expectedFilterSet) - dataManager.saveRevision(expectedRevision) - - XCTAssertEqual(dataManager.filterSet, expectedFilterSet) - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes) - XCTAssertEqual(dataManager.currentRevision, expectedRevision) - - // Test decode JSON data to expected types - let storedHashPrefixesData = fileStore.read(from: Constants.hashPrefixesFileName) - let storedFilterSetData = fileStore.read(from: Constants.filterSetFileName) - let storedRevisionData = fileStore.read(from: "revision.txt") - - let decoder = JSONDecoder() - if let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!), - let storedFilterSet = try? 
decoder.decode(Set.self, from: storedFilterSetData!), - let storedRevisionString = String(data: storedRevisionData!, encoding: .utf8), - let storedRevision = Int(storedRevisionString.trimmingCharacters(in: .whitespacesAndNewlines)) { - - XCTAssertEqual(storedFilterSet, expectedFilterSet) - XCTAssertEqual(storedHashPrefixes, expectedHashPrefixes) - XCTAssertEqual(storedRevision, expectedRevision) - } else { - XCTFail("Failed to decode stored PhishingDetection data") - } + await dataManager.store(HashPrefixSet(revision: expectedRevision, items: expectedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: expectedRevision, items: expectedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 65) + XCTAssertEqual(actualHashPrefixRevision, 65) + + // Test reloading data + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 65) + XCTAssertEqual(reloadedHashPrefixRevision, 65) } func testLazyLoadingDoesNotReturnStaleData() async { 
@@ -158,53 +192,56 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { // Set up initial data let initialFilterSet = Set([Filter(hash: "initial", regex: "initial")]) let initialHashPrefixes = Set(["initialPrefix"]) - embeddedDataProvider.shouldReturnFilterSet(set: initialFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: initialHashPrefixes) + embeddedDataProvider.filterSet = initialFilterSet + embeddedDataProvider.hashPrefixes = initialHashPrefixes // Access the lazy-loaded properties to trigger loading - let loadedFilterSet = dataManager.filterSet - let loadedHashPrefixes = dataManager.hashPrefixes + let loadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let loadedHashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) // Validate loaded data matches initial data - XCTAssertEqual(loadedFilterSet, initialFilterSet) - XCTAssertEqual(loadedHashPrefixes, initialHashPrefixes) + XCTAssertEqual(loadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(loadedHashPrefixes.set, initialHashPrefixes) // Update in-memory data let updatedFilterSet = Set([Filter(hash: "updated", regex: "updated")]) let updatedHashPrefixes = Set(["updatedPrefix"]) - dataManager.saveFilterSet(set: updatedFilterSet) - dataManager.saveHashPrefixes(set: updatedHashPrefixes) - - // Access lazy-loaded properties again - let reloadedFilterSet = dataManager.filterSet - let reloadedHashPrefixes = dataManager.hashPrefixes - - // Validate reloaded data matches updated data - XCTAssertEqual(reloadedFilterSet, updatedFilterSet) - XCTAssertEqual(reloadedHashPrefixes, updatedHashPrefixes) - - // Validate on-disk data is also updated - let storedFilterSetData = fileStore.read(from: Constants.filterSetFileName) - let storedHashPrefixesData = fileStore.read(from: Constants.hashPrefixesFileName) - - let decoder = JSONDecoder() - if let storedFilterSet = try? 
decoder.decode(Set.self, from: storedFilterSetData!), - let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!) { - - XCTAssertEqual(storedFilterSet, updatedFilterSet) - XCTAssertEqual(storedHashPrefixes, updatedHashPrefixes) - } else { - XCTFail("Failed to decode stored PhishingDetection data after update") - } + await dataManager.store(HashPrefixSet(revision: 1, items: updatedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: 1, items: updatedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: 1, items: updatedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, updatedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) + + // Test reloading data – embedded data should be returned as its revision is greater + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, initialHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 1) + XCTAssertEqual(reloadedHashPrefixRevision, 1) } } class MockMaliciousSiteProtectionFileStore: MaliciousSiteProtection.FileStoring { + private var data: [String: Data] = [:] - func write(data: Data, to filename: String) { + func write(data: 
Data, to filename: String) -> Bool { self.data[filename] = data + return true } func read(from filename: String) -> Data? { diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift index 6246e1f93..1e3e0df40 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift @@ -21,28 +21,42 @@ import XCTest @testable import MaliciousSiteProtection class MaliciousSiteProtectionEmbeddedDataProviderTest: XCTestCase { - var filterSetURL: URL! - var hashPrefixURL: URL! - var dataProvider: MaliciousSiteProtection.EmbeddedDataProvider! - override func setUp() { - super.setUp() - filterSetURL = Bundle.module.url(forResource: "phishingFilterSet", withExtension: "json")! - hashPrefixURL = Bundle.module.url(forResource: "phishingHashPrefixes", withExtension: "json")! - } + struct TestEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + func revision(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + 0 + } + + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { + case .filterSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)FilterSet", withExtension: "json")! + case .hashPrefixSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)HashPrefixes", withExtension: "json")! 
+ } + } - override func tearDown() { - filterSetURL = nil - hashPrefixURL = nil - dataProvider = nil - super.tearDown() + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + switch dataType { + case .filterSet(let key): + switch key.threatKind { + case .phishing: + "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504" + } + case .hashPrefixSet(let key): + switch key.threatKind { + case .phishing: + "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f" + } + } + } } func testDataProviderLoadsJSON() { - dataProvider = .init(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f") + let dataProvider = TestEmbeddedDataProvider() let expectedFilter = Filter(hash: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$") - XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().contains(expectedFilter)) - XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().contains("012db806")) + XCTAssertTrue(dataProvider.loadDataSet(for: .filterSet(threatKind: .phishing)).contains(expectedFilter)) + XCTAssertTrue(dataProvider.loadDataSet(for: .hashPrefixes(threatKind: .phishing)).contains("012db806")) } } diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift index 53d68dbd5..8d46d5cf7 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift @@ -16,51 +16,64 @@ // limitations under the License. 
// +import Clocks +import Common import Foundation import XCTest @testable import MaliciousSiteProtection class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { + var updateManager: MaliciousSiteProtection.UpdateManager! - var dataManager: MaliciousSiteProtection.DataManaging! - var apiClient: MaliciousSiteProtection.APIClientProtocol! + var dataManager: MockMaliciousSiteProtectionDataManager! + var apiClient: MaliciousSiteProtection.APIClient.Mockable! + var updateIntervalProvider: UpdateManager.UpdateIntervalProvider! + var clock: TestClock! + var willSleep: ((TimeInterval) -> Void)? + var updateTask: Task? override func setUp() async throws { - try await super.setUp() apiClient = MockMaliciousSiteProtectionAPIClient() dataManager = MockMaliciousSiteProtectionDataManager() - updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager) - dataManager.saveRevision(0) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + clock = TestClock() + + let clockSleeper = Sleeper(clock: clock) + let reportingSleeper = Sleeper { + self.willSleep?($0) + try await clockSleeper.sleep(for: $0) + } + + updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager, sleeper: reportingSleeper, updateIntervalProvider: { self.updateIntervalProvider($0) }) } - override func tearDown() { + override func tearDown() async throws { updateManager = nil dataManager = nil apiClient = nil - super.tearDown() + updateIntervalProvider = nil + updateTask?.cancel() } func testUpdateHashPrefixes() async { - await updateManager.updateHashPrefixes() - XCTAssertFalse(dataManager.hashPrefixes.isEmpty, "Hash prefixes should not be empty after update.") - XCTAssertEqual(dataManager.hashPrefixes, [ + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + let dataSet = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + XCTAssertEqual(dataSet, 
HashPrefixSet(revision: 1, items: [ "aa00bb11", "bb00cc11", "cc00dd11", "dd00ee11", "a379a6f6" - ]) + ])) } func testUpdateFilterSet() async { - await updateManager.updateFilterSet() - XCTAssertEqual(dataManager.filterSet, [ + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + let dataSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(dataSet, FilterDictionary(revision: 1, items: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") - ]) + ])) } func testRevision1AddsAndDeletesData() async { @@ -75,19 +88,27 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { "93e2435e" ] - // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(1) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + // revision 0 -> 1 + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + // revision 1 -> 2 + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 2, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 2, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision2AddsAndDeletesData() async { let expectedFilterSet: Set 
= [ Filter(hash: "testhash4", regex: ".*test.*"), - Filter(hash: "testhash1", regex: ".*example.*") + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), ] let expectedHashPrefixes: Set = [ "aa00bb11", @@ -98,46 +119,274 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { ] // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(2) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await dataManager.store(FilterDictionary(revision: 2, items: [ + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 2, items: [ + "aa00bb11", + "bb00cc11", + "cc00dd11", + "dd00ee11", + "a379a6f6" + ]), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision3AddsAndDeletesNothing() async { - let expectedFilterSet = 
dataManager.filterSet - let expectedHashPrefixes = dataManager.hashPrefixes + let expectedFilterSet: Set = [] + let expectedHashPrefixes: Set = [] // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(3) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await dataManager.store(FilterDictionary(revision: 3, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 3, items: []), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision4AddsAndDeletesData() async { let expectedFilterSet: Set = [ - Filter(hash: "testhash2", regex: ".*test.*"), - Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash5", regex: ".*test.*") ] let expectedHashPrefixes: Set = [ "a379a6f6", + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 4, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 4, items: []), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: 
.filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 5, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 5, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision5replacesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash6", regex: ".*test6.*") + ] + let expectedHashPrefixes: Set = [ + "aa55aa55" + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 5, items: [ + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash5", regex: ".*test.*") + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 5, items: [ + "a379a6f6", "dd00ee11", "cc00dd11", "bb00cc11" + ]), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 6, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 6, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testWhenPeriodicUpdatesStart_dataSetsAreUpdated() async throws { + self.updateIntervalProvider = { _ in 1 } + + let 
eHashPrefixesUpdated = expectation(description: "Hash prefixes updated") + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + eHashPrefixesUpdated.fulfill() + } + let eFilterSetUpdated = expectation(description: "Filter set updated") + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + eFilterSetUpdated.fulfill() + } + + updateTask = updateManager.startPeriodicUpdates() + await Task.megaYield(count: 10) + + // expect initial update run instantly + await fulfillment(of: [eHashPrefixesUpdated, eFilterSetUpdated], timeout: 1) + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreEnabled_dataSetsAreUpdatedContinuously() async throws { + // Start periodic updates + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return 2 + case .hashPrefixSet: return 1 + } + } + + let hashPrefixUpdateExpectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let filterSetUpdateExpectations = [ + XCTestExpectation(description: "Filter set rev.1 update received"), + XCTestExpectation(description: "Filter set rev.2 update received"), + XCTestExpectation(description: "Filter set rev.3 update received"), + ] + let hashPrefixSleepExpectations = [ + XCTestExpectation(description: "HP Will Sleep 1"), + XCTestExpectation(description: "HP Will Sleep 2"), + XCTestExpectation(description: "HP Will Sleep 3"), + ] + let filterSetSleepExpectations = [ + XCTestExpectation(description: "FS Will Sleep 1"), + XCTestExpectation(description: "FS Will Sleep 2"), + XCTestExpectation(description: "FS Will Sleep 3"), ] - // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(4) - await updateManager.updateFilterSet() - await 
updateManager.updateHashPrefixes() + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + hashPrefixUpdateExpectations[data.revision - 1].fulfill() + } + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + filterSetUpdateExpectations[data.revision - 1].fulfill() + } + var hashPrefixSleepIndex = 0 + var filterSetSleepIndex = 0 + self.willSleep = { interval in + if interval == 1 { + hashPrefixSleepExpectations[safe: hashPrefixSleepIndex]?.fulfill() + hashPrefixSleepIndex += 1 + } else { + filterSetSleepExpectations[safe: filterSetSleepIndex]?.fulfill() + filterSetSleepIndex += 1 + } + } + + // expect initial hashPrefixes and filterSet updates to run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [hashPrefixUpdateExpectations[0], hashPrefixSleepExpectations[0], filterSetUpdateExpectations[0], filterSetSleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 second + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [hashPrefixUpdateExpectations[1], hashPrefixSleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 second + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes and v.2 update for filterSet + await fulfillment(of: [hashPrefixUpdateExpectations[2], hashPrefixSleepExpectations[2], filterSetUpdateExpectations[1], filterSetSleepExpectations[1]], timeout: 1) // - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + // Advance the clock by 2 seconds + await self.clock.advance(by: .seconds(2)) + // expect to receive v.3 update for filterSet and no update for hashPrefixes (no v.3 updates in the mock) + await fulfillment(of: 
[filterSetUpdateExpectations[2], filterSetSleepExpectations[2]], timeout: 1) // + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreDisabled_noDataSetsAreUpdated() async throws { + // Configure update intervals: periodic updates are disabled for FilterSet only + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return nil // Set update interval to nil for FilterSet + case .hashPrefixSet: return 1 + } + } + + let expectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + expectations[data.revision - 1].fulfill() + } + // data for FilterSet should not be updated + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + XCTFail("Unexpected filter set update received: \(data)") + } + // synchronize with the update Task so the Test Clock is only advanced while that Task is sleeping, + // otherwise we'll eventually advance the clock before the sleep starts and the test will hang. 
+ var sleepIndex = 0 + let sleepExpectations = [ + XCTestExpectation(description: "Will Sleep 1"), + XCTestExpectation(description: "Will Sleep 2"), + XCTestExpectation(description: "Will Sleep 3"), + ] + self.willSleep = { _ in + sleepExpectations[safe: sleepIndex]?.fulfill() + sleepIndex += 1 + } + + // expect initial hashPrefixes update to run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [expectations[0], sleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 second + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [expectations[1], sleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 second + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes + await fulfillment(of: [expectations[2], sleepExpectations[2]], timeout: 1) + + withExtendedLifetime((c1, c2)) {} } + + func testWhenPeriodicUpdatesAreCancelled_noFurtherUpdatesReceived() async throws { + // Start periodic updates + self.updateIntervalProvider = { _ in 1 } + updateTask = updateManager.startPeriodicUpdates() + + // Wait for the initial update (rev.1) of BOTH data sets, so cancellation cannot race either of them + try await withTimeout(1) { [self] in + for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + for await _ in await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + } + + // Cancel the update task + updateTask!.cancel() + + // Any further change of the data store after cancellation is a failure + let c = await dataManager.$store.dropFirst().sink { data in + XCTFail("Unexpected data update received: \(data)") + } + + // Advance the clock to check for further updates + await self.clock.advance(by: .seconds(2)) + await clock.run() + await Task.megaYield(count: 10) + + // Verify that the data sets have not been updated further + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: 
.phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(hashPrefixes.revision, 1) // Expecting revision to remain 1 + XCTAssertEqual(filterSet.revision, 1) // Expecting revision to remain 1 + + withExtendedLifetime(c) {} + } + } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index 24d1c203b..4f2062edd 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -17,32 +17,37 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClientProtocol { - public var updateHashPrefixesWasCalled: Bool = false - public var updateFilterSetsWasCalled: Bool = false +class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mockable { + var updateHashPrefixesCalled: ((Int) -> Void)? + var updateFilterSetsCalled: ((Int) -> Void)? 
- private var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ + var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ 0: .init(insert: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") - ], delete: [], revision: 0, replace: true), + ], delete: [], revision: 1, replace: false), 1: .init(insert: [ Filter(hash: "testhash3", regex: ".*test.*") ], delete: [ Filter(hash: "testhash1", regex: ".*example.*"), - ], revision: 1, replace: false), + ], revision: 2, replace: false), 2: .init(insert: [ Filter(hash: "testhash4", regex: ".*test.*") ], delete: [ Filter(hash: "testhash2", regex: ".*test.*"), - ], revision: 2, replace: false), + ], revision: 3, replace: false), 4: .init(insert: [ Filter(hash: "testhash5", regex: ".*test.*") ], delete: [ Filter(hash: "testhash3", regex: ".*test.*"), - ], revision: 4, replace: false) + ], revision: 5, replace: false), + 5: .init(insert: [ + Filter(hash: "testhash6", regex: ".*test6.*") + ], delete: [ + Filter(hash: "testhash3", regex: ".*test.*"), + ], revision: 6, replace: true), ] private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [ @@ -52,42 +57,45 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl "cc00dd11", "dd00ee11", "a379a6f6" - ], delete: [], revision: 0, replace: true), + ], delete: [], revision: 1, replace: false), 1: .init(insert: ["93e2435e"], delete: [ "cc00dd11", "dd00ee11", - ], revision: 1, replace: false), + ], revision: 2, replace: false), 2: .init(insert: ["c0be0d0a6"], delete: [ "bb00cc11", - ], revision: 2, replace: false), + ], revision: 3, replace: false), 4: .init(insert: ["a379a6f6"], delete: [ "aa00bb11", - ], revision: 4, replace: false) + ], revision: 5, replace: false), + 5: .init(insert: ["aa55aa55"], delete: [ + "ffgghhzz", + ], revision: 6, replace: true), ] - public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request: 
APIRequestProtocol { + func load(_ requestConfig: Request) async throws -> Request.Response where Request: APIClient.Request { switch requestConfig.requestType { case .hashPrefixSet(let configuration): - return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response case .filterSet(let configuration): - return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response case .matches(let configuration): - return _matches(forHashPrefix: configuration.hashPrefix) as! Request.ResponseType + return _matches(forHashPrefix: configuration.hashPrefix) as! Request.Response } } func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { - updateFilterSetsWasCalled = true + updateFilterSetsCalled?(revision) return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet { - updateHashPrefixesWasCalled = true + updateHashPrefixesCalled?(revision) return hashPrefixRevisions[revision] ?? 
.init(insert: [], delete: [], revision: revision, replace: false) } func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches { .init(matches: [ - Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), + Match(hostname: "example.com", url: "https://example.com/mal", regex: ".*", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil) ]) } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index 64d82bbea..1a67ad329 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -16,29 +16,25 @@ // limitations under the License. // +import Combine import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { - public var filterSet: Set - public var hashPrefixes: Set - public var currentRevision: Int +actor MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { - public init() { - filterSet = Set() - hashPrefixes = Set() - currentRevision = 0 + @Published var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() + func publisher(for key: DataKey) -> AnyPublisher where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + $store.map { $0[key.dataType] as? DataKey.DataSet ?? 
.init(revision: 0, items: []) } + .removeDuplicates() + .eraseToAnyPublisher() } - public func saveFilterSet(set: Set) { - filterSet = set + public func dataSet(for key: DataKey) -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + return store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } - public func saveHashPrefixes(set: Set) { - hashPrefixes = set + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + store[key.dataType] = dataSet } - public func saveRevision(_ revision: Int) { - currentRevision = revision - } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 9bb44bbe2..37f6c2962 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -17,31 +17,66 @@ // import Foundation -import MaliciousSiteProtection -public class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { - public var embeddedRevision: Int = 65 +@testable import MaliciousSiteProtection + +final class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + var embeddedRevision: Int = 65 var loadHashPrefixesCalled: Bool = false var loadFilterSetCalled: Bool = true - var hashPrefixes: Set = ["aabb"] - var filterSet: Set = [Filter(hash: "dummyhash", regex: "dummyregex")] + var hashPrefixes: Set = [] { + didSet { + hashPrefixesData = try! JSONEncoder().encode(hashPrefixes) + } + } + var hashPrefixesData: Data! + + var filterSet: Set = [] { + didSet { + filterSetData = try! JSONEncoder().encode(filterSet) + } + } + var filterSetData: Data! 
+ + init() { + hashPrefixes = Set(["aabb"]) + filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) + } - public func shouldReturnFilterSet(set: Set) { - self.filterSet = set + func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + embeddedRevision } - public func shouldReturnHashPrefixes(set: Set) { - self.hashPrefixes = set + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { + case .filterSet: + self.loadFilterSetCalled = true + return URL(string: "filterSet")! + case .hashPrefixSet: + self.loadHashPrefixesCalled = true + return URL(string: "hashPrefixSet")! + } } - public func loadEmbeddedFilterSet() -> Set { - self.loadHashPrefixesCalled = true - return self.filterSet + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + let url = url(for: dataType) + let data = try! data(withContentsOf: url) + let sha = data.sha256 + return sha } - public func loadEmbeddedHashPrefixes() -> Set { - self.loadFilterSetCalled = true - return self.hashPrefixes + func data(withContentsOf url: URL) throws -> Data { + let data: Data + switch url.absoluteString { + case "filterSet": + self.loadFilterSetCalled = true + return filterSetData + case "hashPrefixSet": + self.loadHashPrefixesCalled = true + return hashPrefixesData + default: + fatalError("Unexpected url \(url.absoluteString)") + } } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index 3eb67c06b..b49eac588 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -17,27 +17,40 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection + +class MockPhishingDetectionUpdateManager: 
MaliciousSiteProtection.UpdateManaging { -public class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { var didUpdateFilterSet = false var didUpdateHashPrefixes = false + var startPeriodicUpdatesCalled = false var completionHandler: (() -> Void)? - public func updateFilterSet() async { + func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { + switch key.dataType { + case .filterSet: await updateFilterSet() + case .hashPrefixSet: await updateHashPrefixes() + } + } + + func updateFilterSet() async { didUpdateFilterSet = true checkCompletion() } - public func updateHashPrefixes() async { + func updateHashPrefixes() async { didUpdateHashPrefixes = true checkCompletion() } - private func checkCompletion() { + func checkCompletion() { if didUpdateFilterSet && didUpdateHashPrefixes { completionHandler?() } } + public func startPeriodicUpdates() -> Task { + startPeriodicUpdatesCalled = true + return Task {} + } } diff --git a/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift b/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift deleted file mode 100644 index 3d3ad4a01..000000000 --- a/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift +++ /dev/null @@ -1,48 +0,0 @@ -// -// PhishingDetectionDataActivitiesTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// - -import Foundation -import XCTest -@testable import MaliciousSiteProtection - -class PhishingDetectionDataActivitiesTests: XCTestCase { - var mockUpdateManager: MockPhishingDetectionUpdateManager! - var activities: PhishingDetectionDataActivities! - - override func setUp() { - super.setUp() - mockUpdateManager = MockPhishingDetectionUpdateManager() - activities = PhishingDetectionDataActivities(hashPrefixInterval: 1, filterSetInterval: 1, updateManager: mockUpdateManager) - } - - func testUpdateHashPrefixesAndFilterSetRuns() async { - let expectation = XCTestExpectation(description: "updateHashPrefixes and updateFilterSet completes") - - mockUpdateManager.completionHandler = { - expectation.fulfill() - } - - activities.start() - - await fulfillment(of: [expectation], timeout: 10.0) - - XCTAssertTrue(mockUpdateManager.didUpdateHashPrefixes) - XCTAssertTrue(mockUpdateManager.didUpdateFilterSet) - - } -}