From 50ed41cdbaabdc91dcc4ab65ddd9070909dcd960 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 21:45:35 +0600 Subject: [PATCH 01/15] refactor APIClient --- Package.swift | 1 + Sources/Common/Extensions/URLExtension.swift | 25 +-- .../API/APIClient.swift | 172 +++++++++--------- .../API/APIRequest.swift | 84 +++++++++ .../API/ChangeSetResponse.swift | 7 +- .../API/MatchResponse.swift | 4 + .../MaliciousSiteDetector.swift | 12 +- .../Services/UpdateManager.swift | 32 +++- Sources/Networking/README.md | 12 +- Sources/Networking/v2/APIRequestV2.swift | 48 +++-- Sources/Networking/v2/APIResponseV2.swift | 9 +- .../Extensions/Dictionary+URLQueryItem.swift | 35 ---- Sources/TestUtils/MockAPIService.swift | 12 +- ...aliciousSiteProtectionAPIClientTests.swift | 120 +++++++----- .../Mocks/MockMaliciousSiteDetector.swift | 38 ---- ...MockMaliciousSiteProtectionAPIClient.swift | 25 ++- .../v2/APIRequestV2Tests.swift | 41 ++--- .../NetworkingTests/v2/APIServiceTests.swift | 25 ++- 18 files changed, 387 insertions(+), 315 deletions(-) create mode 100644 Sources/MaliciousSiteProtection/API/APIRequest.swift delete mode 100644 Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift delete mode 100644 Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift diff --git a/Package.swift b/Package.swift index eadd9b6d1..c074ced7e 100644 --- a/Package.swift +++ b/Package.swift @@ -651,6 +651,7 @@ let package = Package( .testTarget( name: "MaliciousSiteProtectionTests", dependencies: [ + "TestUtils", "MaliciousSiteProtection", ], resources: [ diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift index d19751148..78cc2881f 100644 --- a/Sources/Common/Extensions/URLExtension.swift +++ b/Sources/Common/Extensions/URLExtension.swift @@ -354,22 +354,24 @@ extension URL { // MARK: - Parameters + @_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order. public func appendingParameters(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL where QueryParams.Element == (key: String, value: String) { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result + } - return parameters.reduce(self) { partialResult, parameter in - partialResult.appendingParameter( - name: parameter.key, - value: parameter.value, - allowedReservedCharacters: allowedReservedCharacters - ) - } + public func appendingParameters(_ parameters: KeyValuePairs, allowedReservedCharacters: CharacterSet? = nil) -> URL { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result } public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL { - let queryItem = URLQueryItem(percentEncodingName: name, - value: value, - withAllowedCharacters: allowedReservedCharacters) + let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) return self.appending(percentEncodedQueryItem: queryItem) } @@ -383,8 +385,9 @@ extension URL { var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]() existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems) components.percentEncodedQueryItems = existingPercentEncodedQueryItems + let result = components.url ?? self - return components.url ?? self + return result } public func getQueryItems() -> [URLQueryItem]? { diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 50ddfbd6a..678171839 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -18,129 +18,119 @@ import Common import Foundation -import os import Networking public protocol APIClientProtocol { - func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse - func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse - func getMatches(hashPrefix: String) async -> [Match] + func load(_ requestConfig: Request) async throws -> Request.ResponseType } -public protocol URLSessionProtocol { - func data(for request: URLRequest) async throws -> (Data, URLResponse) +public extension APIClientProtocol where Self == APIClient { + static var production: APIClientProtocol { APIClient(environment: .production) } + static var staging: APIClientProtocol { APIClient(environment: .staging) } } -extension URLSession: URLSessionProtocol {} - -extension URLSessionProtocol { - public static var defaultSession: URLSessionProtocol { - return URLSession.shared - } +public protocol APIClientEnvironment { + func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 + func url(for request: APIClient.Request) -> URL } -public struct APIClient: APIClientProtocol { +public extension APIClient { + enum DefaultEnvironment: APIClientEnvironment { - public enum Environment { case production case staging - } + case dev - enum Constants { - static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")! - static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")! - enum APIPath: String { - case filterSet - case hashPrefix - case matches + var endpoint: URL { + switch self { + case .production: URL(string: "https://duckduckgo.com/api/protection/")! + case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")! + case .dev: URL(string: "https://4842-20-93-28-24.ngrok-free.app/api/protection/")! + } } - } - private let endpointURL: URL - private let session: URLSessionProtocol! - private var headers: [String: String]? = [:] + var defaultHeaders: APIRequestV2.HeadersV2 { + .init(userAgent: APIRequest.Headers.userAgent) + } - var filterSetURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue) - } + enum APIPath { + static let filterSet = "filterSet" + static let hashPrefix = "hashPrefix" + static let matches = "matches" + } - var hashPrefixURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue) - } + enum QueryParameter { + static let category = "category" + static let revision = "revision" + static let hashPrefix = "hashPrefix" + } - var matchesURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue) - } + public func url(for request: APIClient.Request) -> URL { + switch request { + case .hashPrefixSet(let configuration): + endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .filterSet(let configuration): + endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .matches(let configuration): + endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix) + } + } - public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) { - switch environment { - case .production: - endpointURL = Constants.productionEndpoint - case .staging: - endpointURL = Constants.stagingEndpoint + public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 { + defaultHeaders } - self.session = session } - public func getFilterSet(revision: Int) async -> FiltersChangeSetResponse { - guard let url = createURL(for: .filterSet, revision: revision) else { - logDebug("🔸 Invalid filterSet revision URL: \(revision)") - return FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: FiltersChangeSetResponse.self) ?? FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) +} + +public struct APIClient: APIClientProtocol { + + let environment: APIClientEnvironment + private let service: APIService + + public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) { + self.init(environment: environment as APIClientEnvironment, service: service) } - public func getHashPrefixes(revision: Int) async -> HashPrefixesChangeSetResponse { - guard let url = createURL(for: .hashPrefix, revision: revision) else { - logDebug("🔸 Invalid hashPrefix revision URL: \(revision)") - return HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: HashPrefixesChangeSetResponse.self) ?? HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) + public init(environment: APIClientEnvironment, service: APIService) { + self.environment = environment + self.service = service } - public func getMatches(hashPrefix: String) async -> [Match] { - let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)] - guard let url = createURL(for: .matches, queryItems: queryItems) else { - logDebug("🔸 Invalid matches URL: \(hashPrefix)") - return [] - } - return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? [] + public func load(_ requestConfig: Request) async throws -> Request.ResponseType { + let requestType = requestConfig.requestType + let headers = environment.headers(for: requestType) + let url = environment.url(for: requestType) + + let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) + let response = try await service.fetch(request: apiRequest) + let result: Request.ResponseType = try response.decodeBody() + + return result } -} -// MARK: Private Methods -extension APIClient { +} - private func logDebug(_ message: String) { - Logger.api.debug("\(message)") +// MARK: - Convenience +extension APIClientProtocol { + public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { + let result = try await load(.filterSet(threatKind: threatKind, revision: revision)) + return result } - private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? { - // Start with the base URL and append the path component - var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true) - var items = queryItems ?? [] - if let revision = revision, revision > 0 { - items.append(URLQueryItem(name: "revision", value: String(revision))) - } - urlComponents?.queryItems = items.isEmpty ? nil : items - return urlComponents?.url + public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { + let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision)) + return result } - private func fetch(url: URL, responseType: T.Type) async -> T? { - var request = URLRequest(url: url) - request.httpMethod = "GET" - request.allHTTPHeaderFields = headers - - do { - let (data, _) = try await session.data(for: request) - if let response = try? JSONDecoder().decode(responseType, from: data) { - return response - } else { - logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)") - } - } catch { - logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)") - } - return nil + public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { + let result = try await load(.matches(hashPrefix: hashPrefix)) + return result } } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift new file mode 100644 index 000000000..af18fb2f2 --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -0,0 +1,84 @@ +// +// APIRequest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public protocol APIRequestProtocol { + associatedtype ResponseType: Decodable + var requestType: APIClient.Request { get } +} + +public extension APIClient { + enum Request { + case hashPrefixSet(HashPrefixes) + case filterSet(FilterSet) + case matches(Matches) + } +} +public extension APIClient.Request { + struct HashPrefixes: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet + + public let threatKind: ThreatKind + public let revision: Int? + + public var requestType: APIClient.Request { + .hashPrefixSet(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIClient.Request { + struct FilterSet: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.FiltersChangeSet + + public let threatKind: ThreatKind + public let revision: Int? + + public var requestType: APIClient.Request { + .filterSet(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.FilterSet { + static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIClient.Request { + struct Matches: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.Matches + + public let hashPrefix: String + + public var requestType: APIClient.Request { + .matches(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.Matches { + static func matches(hashPrefix: String) -> Self { + .init(hashPrefix: hashPrefix) + } +} diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift index 732411895..7988c2a32 100644 --- a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift +++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift @@ -34,7 +34,10 @@ extension APIClient { } } - public typealias FiltersChangeSetResponse = ChangeSetResponse - public typealias HashPrefixesChangeSetResponse = ChangeSetResponse + public enum Response { + public typealias FiltersChangeSet = ChangeSetResponse + public typealias HashPrefixesChangeSet = ChangeSetResponse + public typealias Matches = MatchResponse + } } diff --git a/Sources/MaliciousSiteProtection/API/MatchResponse.swift b/Sources/MaliciousSiteProtection/API/MatchResponse.swift index aaa48b388..2cb6df962 100644 --- a/Sources/MaliciousSiteProtection/API/MatchResponse.swift +++ b/Sources/MaliciousSiteProtection/API/MatchResponse.swift @@ -20,6 +20,10 @@ extension APIClient { public struct MatchResponse: Codable, Equatable { public var matches: [Match] + + public init(matches: [Match]) { + self.matches = matches + } } } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 9ac4c01c2..592e7e852 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -41,10 +41,6 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { self.eventMapping = eventMapping } - private func getMatches(hashPrefix: String) async -> Set { - return Set(await apiClient.getMatches(hashPrefix: hashPrefix)) - } - private func inFilterSet(hash: String) -> Set { return Set(dataManager.filterSet.filter { $0.hash == hash }) } @@ -65,7 +61,13 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { } private func fetchMatches(hashPrefix: String) async -> [Match] { - return await apiClient.getMatches(hashPrefix: hashPrefix) + do { + let response = try await apiClient.matches(forHashPrefix: hashPrefix) + return response.matches + } catch { + Logger.api.error("Failed to fetch matches for hash prefix: \(hashPrefix): \(error.localizedDescription)") + return [] + } } private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool { diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index 053acd230..af9f60e7d 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -54,30 +54,42 @@ public struct UpdateManager: UpdateManaging { } public func updateFilterSet() async { - let response = await apiClient.getFilterSet(revision: dataManager.currentRevision) + let changeSet: APIClient.Response.FiltersChangeSet + do { + changeSet = try await apiClient.filtersChangeSet(for: .phishing, revision: dataManager.currentRevision) + } catch { + Logger.updateManager.error("error fetching filter set: \(error)") + return + } updateSet( currentSet: dataManager.filterSet, - insert: response.insert, - delete: response.delete, - replace: response.replace + insert: changeSet.insert, + delete: changeSet.delete, + replace: changeSet.replace ) { newSet in self.dataManager.saveFilterSet(set: newSet) } - dataManager.saveRevision(response.revision) + dataManager.saveRevision(changeSet.revision) Logger.updateManager.debug("filterSet updated to revision \(self.dataManager.currentRevision)") } public func updateHashPrefixes() async { - let response = await apiClient.getHashPrefixes(revision: dataManager.currentRevision) + let changeSet: APIClient.Response.HashPrefixesChangeSet + do { + changeSet = try await apiClient.hashPrefixesChangeSet(for: .phishing, revision: dataManager.currentRevision) + } catch { + Logger.updateManager.error("error fetching hash prefixes: \(error)") + return + } updateSet( currentSet: dataManager.hashPrefixes, - insert: response.insert, - delete: response.delete, - replace: response.replace + insert: changeSet.insert, + delete: changeSet.delete, + replace: changeSet.replace ) { newSet in self.dataManager.saveHashPrefixes(set: newSet) } - dataManager.saveRevision(response.revision) + dataManager.saveRevision(changeSet.revision) Logger.updateManager.debug("hashPrefixes updated to revision \(self.dataManager.currentRevision)") } } diff --git a/Sources/Networking/README.md b/Sources/Networking/README.md index 83a2c5ce3..751ee63d8 100644 --- a/Sources/Networking/README.md +++ b/Sources/Networking/README.md @@ -19,7 +19,7 @@ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: [.allowHTTPNotModified, .requireETagHeader, .requireUserAgent], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService(urlSession: URLSession.shared) ``` @@ -55,12 +55,12 @@ The `MockPIService` implementing `APIService` can be found in `BSK/TestUtils` ``` let apiResponse = (Data(), HTTPURLResponse(url: HTTPURLResponse.testUrl, - statusCode: 200, - httpVersion: nil, - headerFields: nil)!) -let mockedAPIService = MockAPIService(decodableResponse: Result.failure(SomeError.testError), apiResponse: Result.success(apiResponse) ) + statusCode: 200, + httpVersion: nil, + headerFields: nil)!) +let mockedAPIService = MockAPIService(apiResponse: Result.success(apiResponse)) ``` ## v1 (Legacy) -Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. \ No newline at end of file +Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. diff --git a/Sources/Networking/v2/APIRequestV2.swift b/Sources/Networking/v2/APIRequestV2.swift index 07434de67..a61604861 100644 --- a/Sources/Networking/v2/APIRequestV2.swift +++ b/Sources/Networking/v2/APIRequestV2.swift @@ -16,12 +16,11 @@ // limitations under the License. // +import Common import Foundation public struct APIRequestV2: CustomDebugStringConvertible { - public typealias QueryItems = [String: String] - let timeoutInterval: TimeInterval let responseConstraints: [APIResponseConstraints]? public let urlRequest: URLRequest @@ -37,25 +36,25 @@ public struct APIRequestV2: CustomDebugStringConvertible { /// - cachePolicy: The request cache policy, default is `.useProtocolCachePolicy` /// - responseRequirements: The response requirements /// - allowedQueryReservedCharacters: The characters in this character set will not be URL encoded in the query parameters - public init?(url: URL, - method: HTTPRequestMethod = .get, - queryItems: QueryItems? = nil, - headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), - body: Data? = nil, - timeoutInterval: TimeInterval = 60.0, - cachePolicy: URLRequest.CachePolicy? = nil, - responseConstraints: [APIResponseConstraints]? = nil, - allowedQueryReservedCharacters: CharacterSet? = nil) { + public init( + url: URL, + method: HTTPRequestMethod = .get, + queryItems: QueryParams?, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) where QueryParams.Element == (key: String, value: String) { + self.timeoutInterval = timeoutInterval self.responseConstraints = responseConstraints - // Generate URL request - guard var urlComps = URLComponents(url: url, resolvingAgainstBaseURL: true) else { - return nil - } - urlComps.queryItems = queryItems?.toURLQueryItems(allowedReservedCharacters: allowedQueryReservedCharacters) - guard let finalURL = urlComps.url else { - return nil + let finalURL = if let queryItems { + url.appendingParameters(queryItems, allowedReservedCharacters: allowedQueryReservedCharacters) + } else { + url } var request = URLRequest(url: finalURL, timeoutInterval: timeoutInterval) request.allHTTPHeaderFields = headers?.httpHeaders @@ -67,6 +66,19 @@ public struct APIRequestV2: CustomDebugStringConvertible { self.urlRequest = request } + public init( + url: URL, + method: HTTPRequestMethod = .get, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) { + self.init(url: url, method: method, queryItems: [String: String]?.none, headers: headers, body: body, timeoutInterval: timeoutInterval, cachePolicy: cachePolicy, responseConstraints: responseConstraints, allowedQueryReservedCharacters: allowedQueryReservedCharacters) + } + public var debugDescription: String { """ APIRequestV2: diff --git a/Sources/Networking/v2/APIResponseV2.swift b/Sources/Networking/v2/APIResponseV2.swift index 1b178fd93..8987e377b 100644 --- a/Sources/Networking/v2/APIResponseV2.swift +++ b/Sources/Networking/v2/APIResponseV2.swift @@ -20,8 +20,13 @@ import Foundation import os.log public struct APIResponseV2 { - let data: Data? - let httpResponse: HTTPURLResponse + public let data: Data? + public let httpResponse: HTTPURLResponse + + public init(data: Data?, httpResponse: HTTPURLResponse) { + self.data = data + self.httpResponse = httpResponse + } } public extension APIResponseV2 { diff --git a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift b/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift deleted file mode 100644 index 81a4648d6..000000000 --- a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift +++ /dev/null @@ -1,35 +0,0 @@ -// -// Dictionary+URLQueryItem.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common - -extension Dictionary where Key == String, Value == String { - - func toURLQueryItems(allowedReservedCharacters: CharacterSet? = nil) -> [URLQueryItem] { - return self.map { - if let allowedReservedCharacters { - URLQueryItem(percentEncodingName: $0.key, - value: $0.value, - withAllowedCharacters: allowedReservedCharacters) - } else { - URLQueryItem(name: $0.key, value: $0.value) - } - } - } -} diff --git a/Sources/TestUtils/MockAPIService.swift b/Sources/TestUtils/MockAPIService.swift index be4bf47a2..f4d35b4b6 100644 --- a/Sources/TestUtils/MockAPIService.swift +++ b/Sources/TestUtils/MockAPIService.swift @@ -19,12 +19,16 @@ import Foundation import Networking -public struct MockAPIService: APIService { +public class MockAPIService: APIService { - public var apiResponse: Result + public var requestHandler: ((APIRequestV2) -> Result)! - public func fetch(request: Networking.APIRequestV2) async throws -> APIResponseV2 { - switch apiResponse { + public init(requestHandler: ((APIRequestV2) -> Result)? = nil) { + self.requestHandler = requestHandler + } + + public func fetch(request: APIRequestV2) async throws -> APIResponseV2 { + switch requestHandler!(request) { case .success(let result): return result case .failure(let error): diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index 52f74b97c..e1fd6db4f 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -16,110 +16,130 @@ // limitations under the License. // import Foundation +import Networking +import TestUtils import XCTest + @testable import MaliciousSiteProtection final class MaliciousSiteProtectionAPIClientTests: XCTestCase { - var mockSession: MockURLSession! + var mockService: MockAPIService! var client: MaliciousSiteProtection.APIClient! override func setUp() { super.setUp() - mockSession = MockURLSession() - client = .init(environment: .staging, session: mockSession) + mockService = MockAPIService() + client = .init(environment: .staging, service: mockService) } override func tearDown() { - mockSession = nil + mockService = nil client = nil super.tearDown() } - func testGetFilterSetSuccess() async { + func testWhenPhishingFilterSetRequestedAndSucceeds_ChangeSetIsReturned() async throws { // Given let insertFilter = Filter(hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".") let deleteFilter = Filter(hash: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?") - let expectedResponse = APIClient.FiltersChangeSetResponse(insert: [insertFilter], delete: [deleteFilter], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.filterSetURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.FiltersChangeSet(insert: [insertFilter], delete: [deleteFilter], revision: 666, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .filterSet(.init(threatKind: .phishing, revision: 666)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getFilterSet(revision: 1) + let response = try await client.filtersChangeSet(for: .phishing, revision: 666) // Then XCTAssertEqual(response, expectedResponse) } - func testGetHashPrefixesSuccess() async { + func testWhenHashPrefixesRequestedAndSucceeds_ChangeSetIsReturned() async throws { // Given - let expectedResponse = APIClient.HashPrefixesChangeSetResponse(insert: ["abc"], delete: ["def"], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.hashPrefixURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.HashPrefixesChangeSet(insert: ["abc"], delete: ["def"], revision: 1, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getHashPrefixes(revision: 1) + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: 1) // Then XCTAssertEqual(response, expectedResponse) } - func testGetMatchesSuccess() async { + func testWhenMatchesRequestedAndSucceeds_MatchesAreReturned() async throws { // Given - let expectedResponse = APIClient.MatchResponse(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.matchesURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.Matches(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .matches(.init(hashPrefix: "abc")))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getMatches(hashPrefix: "abc") + let response = try await client.matches(forHashPrefix: "abc") // Then - XCTAssertEqual(response, expectedResponse.matches) + XCTAssertEqual(response.matches, expectedResponse.matches) } - func testGetFilterSetInvalidURL() async { + func testWhenHashPrefixesRequestFails_ErrorThrown() async throws { // Given let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getFilterSet(revision: invalidRevision) - - // Then - XCTAssertEqual(response, .init(insert: [], delete: [], revision: invalidRevision, replace: false)) + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } } - func testGetHashPrefixesInvalidURL() async { + func testWhenFilterSetRequestFails_ErrorThrown() async throws { // Given let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getHashPrefixes(revision: invalidRevision) - - // Then - XCTAssertEqual(response, .init(insert: [], delete: [], revision: invalidRevision, replace: false)) + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } } - func testGetMatchesInvalidURL() async { + + func testWhenMatchesRequestFails_ErrorThrown() async throws { // Given let invalidHashPrefix = "" + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getMatches(hashPrefix: invalidHashPrefix) - - // Then - XCTAssertTrue(response.isEmpty) - } -} - -class MockURLSession: URLSessionProtocol { - var data: Data? - var response: URLResponse? - var error: Error? - - func data(for request: URLRequest) async throws -> (Data, URLResponse) { - if let error = error { - throw error + do { + let response = try await client.matches(forHashPrefix: invalidHashPrefix) + XCTFail("Unexpected \(response) expected throw") + } catch { } - return (data ?? Data(), response ?? URLResponse()) } + } + diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift deleted file mode 100644 index 0d54cd459..000000000 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift +++ /dev/null @@ -1,38 +0,0 @@ -// -// MockMaliciousSiteDetector.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import MaliciousSiteProtection - -public class MockMaliciousSiteDetector: MaliciousSiteDetecting { - private var mockClient: MaliciousSiteProtection.APIClientProtocol - public var didCallIsMalicious: Bool = false - - init() { - self.mockClient = MockMaliciousSiteProtectionAPIClient() - } - - public func getMatches(hashPrefix: String) async -> Set { - let matches = await mockClient.getMatches(hashPrefix: hashPrefix) - return Set(matches) - } - - public func evaluate(_ url: URL) async -> ThreatKind? { - return url.absoluteString.contains("malicious") ? .phishing : nil - } -} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index ad2a31fe9..d47552491 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -23,7 +23,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl public var updateHashPrefixesWasCalled: Bool = false public var updateFilterSetsWasCalled: Bool = false - private var filterRevisions: [Int: APIClient.FiltersChangeSetResponse] = [ + private var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ 0: .init(insert: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") @@ -45,7 +45,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - private var hashPrefixRevisions: [Int: APIClient.HashPrefixesChangeSetResponse] = [ + private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [ 0: .init(insert: [ "aa00bb11", "bb00cc11", @@ -65,20 +65,31 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - public func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse { + public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request : APIRequestProtocol { + switch requestConfig.requestType { + case .hashPrefixSet(let configuration): + return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + case .filterSet(let configuration): + return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + case .matches(let configuration): + return _matches(forHashPrefix: configuration.hashPrefix) as! Request.ResponseType + } + } + func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { updateFilterSetsWasCalled = true return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } - public func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse { + func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet { updateHashPrefixesWasCalled = true return hashPrefixRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } - public func getMatches(hashPrefix: String) async -> [Match] { - return [ + func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches { + .init(matches: [ Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil) - ] + ]) } + } diff --git a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift index 59eeadebb..4ec1b8b59 100644 --- a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift +++ b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift @@ -41,21 +41,18 @@ final class APIRequestV2Tests: XCTestCase { cachePolicy: cachePolicy, responseConstraints: constraints) - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.url?.host(), url.host()) XCTAssertEqual(urlRequest.httpMethod, method.rawValue) let urlComponents = URLComponents(string: urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertEqual(urlRequest.allHTTPHeaderFields, headers.httpHeaders) XCTAssertEqual(urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.timeoutInterval, timeoutInterval) XCTAssertEqual(urlRequest.cachePolicy, cachePolicy) - XCTAssertEqual(apiRequest?.responseConstraints, constraints) + XCTAssertEqual(apiRequest.responseConstraints, constraints) } func testURLRequestGeneration() { @@ -75,16 +72,16 @@ final class APIRequestV2Tests: XCTestCase { timeoutInterval: timeoutInterval, cachePolicy: cachePolicy) - let urlComponents = URLComponents(string: apiRequest!.urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + let urlComponents = URLComponents(string: apiRequest.urlRequest.url!.absoluteString)! + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertNotNil(apiRequest) - XCTAssertEqual(apiRequest?.urlRequest.url?.absoluteString, "https://www.example.com?key=value") - XCTAssertEqual(apiRequest?.urlRequest.httpMethod, method.rawValue) - XCTAssertEqual(apiRequest?.urlRequest.allHTTPHeaderFields, headers.httpHeaders) - XCTAssertEqual(apiRequest?.urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.urlRequest.timeoutInterval, timeoutInterval) - XCTAssertEqual(apiRequest?.urlRequest.cachePolicy, cachePolicy) + XCTAssertEqual(apiRequest.urlRequest.url?.absoluteString, "https://www.example.com?key=value") + XCTAssertEqual(apiRequest.urlRequest.httpMethod, method.rawValue) + XCTAssertEqual(apiRequest.urlRequest.allHTTPHeaderFields, headers.httpHeaders) + XCTAssertEqual(apiRequest.urlRequest.httpBody, body) + XCTAssertEqual(apiRequest.urlRequest.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.urlRequest.cachePolicy, cachePolicy) } func testDefaultValues() { @@ -92,16 +89,13 @@ final class APIRequestV2Tests: XCTestCase { let apiRequest = APIRequestV2(url: url) let headers = APIRequestV2.HeadersV2() - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.httpMethod, HTTPRequestMethod.get.rawValue) XCTAssertEqual(urlRequest.timeoutInterval, 60.0) XCTAssertEqual(headers.httpHeaders, urlRequest.allHTTPHeaderFields) XCTAssertNil(urlRequest.httpBody) XCTAssertEqual(urlRequest.cachePolicy.rawValue, 0) - XCTAssertNil(apiRequest?.responseConstraints) + XCTAssertNil(apiRequest.responseConstraints) } func testAllowedQueryReservedCharacters() { @@ -112,9 +106,10 @@ final class APIRequestV2Tests: XCTestCase { queryItems: queryItems, allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) - let urlString = apiRequest!.urlRequest.url!.absoluteString - XCTAssertTrue(urlString == "https://www.example.com?k%2523e,y=val%2523ue") + let urlString = apiRequest.urlRequest.url!.absoluteString + XCTAssertEqual(urlString, "https://www.example.com?k%23e,y=val%23ue") + let urlComponents = URLComponents(string: urlString)! - XCTAssertTrue(urlComponents.queryItems?.count == 1) + XCTAssertEqual(urlComponents.queryItems?.count, 1) } } diff --git a/Tests/NetworkingTests/v2/APIServiceTests.swift b/Tests/NetworkingTests/v2/APIServiceTests.swift index 9cae44323..730d6afbb 100644 --- a/Tests/NetworkingTests/v2/APIServiceTests.swift +++ b/Tests/NetworkingTests/v2/APIServiceTests.swift @@ -41,7 +41,7 @@ final class APIServiceTests: XCTestCase { cachePolicy: .reloadIgnoringLocalAndRemoteCacheData, responseConstraints: [APIResponseConstraints.allowHTTPNotModified, APIResponseConstraints.requireETagHeader], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService() let response = try await apiService.fetch(request: request) let responseHTML: String = try response.decodeBody() @@ -50,7 +50,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallJSON() async throws { // func testRealCallJSON() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -63,7 +63,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallString() async throws { // func testRealCallString() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -75,17 +75,16 @@ final class APIServiceTests: XCTestCase { "qName2": "qValue2"] MockURLProtocol.requestHandler = { request in let urlComponents = URLComponents(string: request.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(qItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(qItems.map { URLQueryItem(name: $0.key, value: $0.value) })) return (HTTPURLResponse.ok, nil) } - let request = APIRequestV2(url: HTTPURLResponse.testUrl, - queryItems: qItems)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems) let apiService = DefaultAPIService(urlSession: mockURLSession) _ = try await apiService.fetch(request: request) } func testURLRequestError() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) enum TestError: Error { case anError @@ -111,7 +110,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementAllowHTTPNotModifiedSuccess() async throws { let requirements = [APIResponseConstraints.allowHTTPNotModified ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -122,7 +121,7 @@ final class APIServiceTests: XCTestCase { } func testResponseRequirementAllowHTTPNotModifiedFailure() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -147,7 +146,7 @@ final class APIServiceTests: XCTestCase { let requirements: [APIResponseConstraints] = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } // HTTPURLResponse.ok contains etag let apiService = DefaultAPIService(urlSession: mockURLSession) @@ -158,7 +157,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireETagHeaderFailure() async throws { let requirements = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okNoEtag, nil) } @@ -181,7 +180,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentSuccess() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okUserAgent, nil) @@ -194,7 +193,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentFailure() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } From 9a32e36c267cf135ae5a48636dadbfeec23da179 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 21:55:58 +0600 Subject: [PATCH 02/15] fix linter issues --- .../MaliciousSiteProtectionAPIClientTests.swift | 2 -- .../Mocks/MockMaliciousSiteProtectionAPIClient.swift | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index e1fd6db4f..d32d264ea 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -124,7 +124,6 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { } } - func testWhenMatchesRequestFails_ErrorThrown() async throws { // Given let invalidHashPrefix = "" @@ -142,4 +141,3 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { } } - diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index d47552491..24d1c203b 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -65,7 +65,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request : APIRequestProtocol { + public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request: APIRequestProtocol { switch requestConfig.requestType { case .hashPrefixSet(let configuration): return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType From b2bdcc7fedf176fbf167d17bd76e25e488e71b86 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 22:02:49 +0600 Subject: [PATCH 03/15] fix failing test --- Sources/Common/Extensions/URLExtension.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift index 78cc2881f..ce68773d5 100644 --- a/Sources/Common/Extensions/URLExtension.swift +++ b/Sources/Common/Extensions/URLExtension.swift @@ -380,7 +380,8 @@ extension URL { } public func appending(percentEncodedQueryItems: [URLQueryItem]) -> URL { - guard var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } + guard !percentEncodedQueryItems.isEmpty, + var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]() existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems) From ed8ebb931dfaa969a6fdf0fa6da996d4d5e2f7ea Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Tue, 26 Nov 2024 19:13:39 +0600 Subject: [PATCH 04/15] Refactor Data storing --- .../Common/Extensions/StringExtension.swift | 8 +- .../API/APIRequest.swift | 17 +- .../API/ChangeSetResponse.swift | 4 + .../MaliciousSiteDetector.swift | 116 ++++----- .../Model/FilterDictionary.swift | 68 ++++++ .../Model/HashPrefixSet.swift | 44 ++++ ...entallyUpdatableMaliciousSiteDataSet.swift | 70 ++++++ .../Model/LoadableFromEmbeddedData.swift | 34 +++ .../Model/StoredData.swift | 97 ++++++++ .../Services/DataManager.swift | 183 +++++---------- .../Services/EmbeddedDataProvider.swift | 63 ++--- .../Services/FileStore.swift | 23 +- .../Services/UpdateManager.swift | 73 +++--- .../MaliciousSiteDetectorTests.swift | 26 +-- ...iciousSiteProtectionDataManagerTests.swift | 221 ++++++++++-------- ...teProtectionEmbeddedDataProviderTest.swift | 46 ++-- ...iousSiteProtectionUpdateManagerTests.swift | 117 +++++++--- ...MockMaliciousSiteProtectionAPIClient.swift | 26 ++- ...ckMaliciousSiteProtectionDataManager.swift | 21 +- ...usSiteProtectionEmbeddedDataProvider.swift | 33 +-- .../MockPhishingDetectionUpdateManager.swift | 8 + 21 files changed, 814 insertions(+), 484 deletions(-) create mode 100644 Sources/MaliciousSiteProtection/Model/FilterDictionary.swift create mode 100644 Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift create mode 100644 Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift create mode 100644 Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift create mode 100644 Sources/MaliciousSiteProtection/Model/StoredData.swift diff --git a/Sources/Common/Extensions/StringExtension.swift b/Sources/Common/Extensions/StringExtension.swift index 09050cfe2..9282a43b4 100644 --- a/Sources/Common/Extensions/StringExtension.swift +++ b/Sources/Common/Extensions/StringExtension.swift @@ -394,9 +394,9 @@ public extension String { // MARK: Regex - func matches(_ regex: NSRegularExpression) -> Bool { - let matches = regex.matches(in: self, options: .anchored, range: self.fullRange) - return matches.count == 1 + func matches(_ regex: RegEx) -> Bool { + let firstMatch = firstMatch(of: regex, options: .anchored) + return firstMatch != nil } func matches(pattern: String, options: NSRegularExpression.Options = [.caseInsensitive]) -> Bool { @@ -406,7 +406,7 @@ public extension String { return matches(regex) } - func replacing(_ regex: NSRegularExpression, with replacement: String) -> String { + func replacing(_ regex: RegEx, with replacement: String) -> String { regex.stringByReplacingMatches(in: self, range: self.fullRange, withTemplate: replacement) } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift index af18fb2f2..0168b01cf 100644 --- a/Sources/MaliciousSiteProtection/API/APIRequest.swift +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -22,6 +22,9 @@ public protocol APIRequestProtocol { associatedtype ResponseType: Decodable var requestType: APIClient.Request { get } } +public protocol MaliciousSiteDataChangeSetAPIRequestProtocol: APIRequestProtocol { + init(threatKind: ThreatKind, revision: Int?) +} public extension APIClient { enum Request { @@ -31,12 +34,17 @@ public extension APIClient { } } public extension APIClient.Request { - struct HashPrefixes: APIRequestProtocol { + struct HashPrefixes: MaliciousSiteDataChangeSetAPIRequestProtocol { public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet public let threatKind: ThreatKind public let revision: Int? + public init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } + public var requestType: APIClient.Request { .hashPrefixSet(self) } @@ -49,12 +57,17 @@ extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes { } public extension APIClient.Request { - struct FilterSet: APIRequestProtocol { + struct FilterSet: MaliciousSiteDataChangeSetAPIRequestProtocol { public typealias ResponseType = APIClient.Response.FiltersChangeSet public let threatKind: ThreatKind public let revision: Int? + public init(threatKind: ThreatKind, revision: Int?) { + self.threatKind = threatKind + self.revision = revision + } + public var requestType: APIClient.Request { .filterSet(self) } diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift index 7988c2a32..eaf4f287c 100644 --- a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift +++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift @@ -32,6 +32,10 @@ extension APIClient { self.revision = revision self.replace = replace } + + public var isEmpty: Bool { + insert.isEmpty && delete.isEmpty + } } public enum Response { diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 592e7e852..206524b9a 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -24,16 +24,21 @@ public protocol MaliciousSiteDetecting { func evaluate(_ url: URL) async -> ThreatKind? } +/// Class responsible for detecting malicious sites by evaluating URLs against local filters and an external API. +/// entry point: `func evaluate(_: URL) async -> ThreatKind?` public final class MaliciousSiteDetector: MaliciousSiteDetecting { - // for easier Xcode symbol navigation + // Type aliases for easier symbol navigation in Xcode. typealias PhishingDetector = MaliciousSiteDetector typealias MalwareDetector = MaliciousSiteDetector - let hashPrefixStoreLength: Int = 8 - let hashPrefixParamLength: Int = 4 - let apiClient: APIClientProtocol - let dataManager: DataManaging - let eventMapping: EventMapping + private enum Constants { + static let hashPrefixStoreLength: Int = 8 + static let hashPrefixParamLength: Int = 4 + } + + private let apiClient: APIClientProtocol + private let dataManager: DataManaging + private let eventMapping: EventMapping public init(apiClient: APIClientProtocol = APIClient(), dataManager: DataManaging, eventMapping: EventMapping) { self.apiClient = apiClient @@ -41,73 +46,70 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { self.eventMapping = eventMapping } - private func inFilterSet(hash: String) -> Set { - return Set(dataManager.filterSet.filter { $0.hash == hash }) - } - - private func matchesUrl(hash: String, regexPattern: String, url: URL, hostnameHash: String) -> Bool { - if hash == hostnameHash, - let regex = try? NSRegularExpression(pattern: regexPattern, options: []) { - let urlString = url.absoluteString - let range = NSRange(location: 0, length: urlString.utf16.count) - return regex.firstMatch(in: urlString, options: [], range: range) != nil - } - return false - } - private func generateHashPrefix(for canonicalHost: String, length: Int) -> String { let hostnameHash = SHA256.hash(data: Data(canonicalHost.utf8)).map { String(format: "%02hhx", $0) }.joined() return String(hostnameHash.prefix(length)) } - private func fetchMatches(hashPrefix: String) async -> [Match] { - do { - let response = try await apiClient.matches(forHashPrefix: hashPrefix) - return response.matches - } catch { - Logger.api.error("Failed to fetch matches for hash prefix: \(hashPrefix): \(error.localizedDescription)") - return [] - } + private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool { + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: threatKind)) + let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) + let matchesLocalFilters = filterSet[hostnameHash]?.contains(where: { regex in + canonicalUrl.absoluteString.matches(pattern: regex) + }) ?? false + + return matchesLocalFilters } - private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool { - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - let filterHit = inFilterSet(hash: hostnameHash) - for filter in filterHit where matchesUrl(hash: filter.hash, regexPattern: filter.regex, url: canonicalUrl, hostnameHash: hostnameHash) { - eventMapping.fire(.errorPageShown(clientSideHit: true)) - return true + private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Match? { + let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: Constants.hashPrefixParamLength) + let matches: [Match] + do { + matches = try await apiClient.matches(forHashPrefix: hashPrefixParam).matches + } catch { + Logger.general.error("Error fetching matches from API: \(error)") + return nil } - return false - } - private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Bool { - let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: hashPrefixParamLength) - let matches = await fetchMatches(hashPrefix: hashPrefixParam) - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - for match in matches where matchesUrl(hash: match.hash, regexPattern: match.regex, url: canonicalUrl, hostnameHash: hostnameHash) { - eventMapping.fire(.errorPageShown(clientSideHit: false)) - return true + if let match = matches.first(where: { match in + canonicalUrl.absoluteString.matches(pattern: match.regex) + }) { + return match } - return false + return nil } + /// Evaluates the given URL to determine its threat level. + /// - Parameter url: The URL to evaluate. + /// - Returns: An optional ThreatKind indicating the type of threat, or nil if no threat is detected. public func evaluate(_ url: URL) async -> ThreatKind? { - guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return .none } - - for threatKind in ThreatKind.allCases { - let hashPrefix = generateHashPrefix(for: canonicalHost, length: hashPrefixStoreLength) - if dataManager.hashPrefixes.contains(hashPrefix) { - // Check local filterSet first - if checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return threatKind - } - // If nothing found, hit the API to get matches - if await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) { - return threatKind - } + guard let canonicalHost = url.canonicalHost(), + let canonicalUrl = url.canonicalURL() else { return .none } + + let hashPrefix = generateHashPrefix(for: canonicalHost, length: Constants.hashPrefixStoreLength) + + for threatKind in ThreatKind.allCases /* phishing, malware.. */ { + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind)) + guard hashPrefixes.contains(hashPrefix) else { continue } + + // Check local filterSet first + if await checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl, for: threatKind) { + eventMapping.fire(.errorPageShown(clientSideHit: true)) + return threatKind + } + + // If nothing found, hit the API to get matches + let match = await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) + if let match { + eventMapping.fire(.errorPageShown(clientSideHit: false)) + return match.category.map(ThreatKind.init) ?? threatKind } + + // the API detects both phishing and malware so if it didn‘t find any matches it‘s safe to return early. + return nil } return .none } + } diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift new file mode 100644 index 000000000..79cb6632e --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -0,0 +1,68 @@ +// +// FilterDictionary.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public struct FilterDictionary: Codable, Equatable { + + /// Filter set revision + public var revision: Int + + /// [Hash: [RegEx]] mapping + /// + /// - **Key**: SHA256 hash sum of a canonical host name + /// - **Value**: An array of regex patterns used to match whole URLs + /// + /// ``` + /// { + /// "3aeb002460381c6f258e8395d3026f571f0d9a76488dcd837639b13aed316560" : [ + /// "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?[\\/\\\\]+BETS1O\\-GIRIS[\\/\\\\]+BETS1O(?:[\\/\\\\]+|\\?|$)" + /// ], + /// ... + /// } + /// ``` + public var filters: [String: Set] + + init(revision: Int, filters: [String: Set]) { + self.filters = filters + self.revision = revision + } + + /// Subscript to access regex patterns by SHA256 host name hash + subscript(hash: String) -> Set? { + filters[hash] + } + + public mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { + for filter in itemsToDelete { + withUnsafeMutablePointer(to: &filters[filter.hash]) { item in + item.pointee?.remove(filter.regex) + if item.pointee?.isEmpty == true { + item.pointee = nil // TODO: Validate deallocation + } + } + } + } + + public mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter { + for filter in itemsToAdd { + filters[filter.hash, default: []].insert(filter.regex) + } + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift new file mode 100644 index 000000000..06a5fc7fa --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift @@ -0,0 +1,44 @@ +// +// HashPrefixSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public struct HashPrefixSet: Codable, Equatable { + + public var revision: Int + public var set: Set + + public init(revision: Int, items: some Sequence) { + self.revision = revision + self.set = Set(items) + } + + public mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String { + set.subtract(itemsToDelete) + } + + public mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String { + set.formUnion(itemsToAdd) + } + + @inline(__always) + func contains(_ item: String) -> Bool { + set.contains(item) + } + +} diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift new file mode 100644 index 000000000..773659cdb --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift @@ -0,0 +1,70 @@ +// +// IncrementallyUpdatableMaliciousSiteDataSet.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +public protocol IncrementallyUpdatableMaliciousSiteDataSet: Codable { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element: Codable, Hashable + associatedtype APIRequestType: MaliciousSiteDataChangeSetAPIRequestProtocol, APIRequestProtocol where APIRequestType.ResponseType == APIClient.ChangeSetResponse + + var revision: Int { get set } + + init(revision: Int, items: some Sequence) + + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Element + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Element + + /// Apply ChangeSet from local data revision to actual revision loaded from API + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) +} + +extension IncrementallyUpdatableMaliciousSiteDataSet { + public mutating func apply(_ changeSet: APIClient.ChangeSetResponse) { + if changeSet.replace { + self = .init(revision: changeSet.revision, items: changeSet.insert) + } else { + self.subtract(changeSet.delete) + self.formUnion(changeSet.insert) + self.revision = changeSet.revision + } + } +} + +extension HashPrefixSet: IncrementallyUpdatableMaliciousSiteDataSet { + public typealias Element = String + public typealias APIRequestType = APIClient.Request.HashPrefixes + + public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequestType { + .hashPrefixes(threatKind: threatKind, revision: revision) + } +} + +extension FilterDictionary: IncrementallyUpdatableMaliciousSiteDataSet { + public typealias Element = Filter + public typealias APIRequestType = APIClient.Request.FilterSet + + public init(revision: Int, items: some Sequence) { + let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in + result[filter.hash, default: []].insert(filter.regex) + } + self.init(revision: revision, filters: filtersDictionary) + } + + public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequestType { + .filterSet(threatKind: threatKind, revision: revision) + } +} diff --git a/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift new file mode 100644 index 000000000..d6eddce09 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift @@ -0,0 +1,34 @@ +// +// LoadableFromEmbeddedData.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +public protocol LoadableFromEmbeddedData { + /// Set Element Type (Hash Prefix or Filter) + associatedtype Element + /// Decoded data type stored in the embedded json file + associatedtype EmbeddedDataSetType: Decodable, Sequence where EmbeddedDataSetType.Element == Self.Element + + init(revision: Int, items: some Sequence) +} + +extension HashPrefixSet: LoadableFromEmbeddedData { + public typealias EmbeddedDataSetType = [String] +} + +extension FilterDictionary: LoadableFromEmbeddedData { + public typealias EmbeddedDataSetType = [Filter] +} diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift new file mode 100644 index 000000000..041fa7390 --- /dev/null +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -0,0 +1,97 @@ +// +// StoredData.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public protocol MaliciousSiteDataKeyProtocol: Hashable { + associatedtype EmbeddedDataSetType: Decodable + associatedtype DataSetType: IncrementallyUpdatableMaliciousSiteDataSet, LoadableFromEmbeddedData + + var dataType: DataManager.StoredDataType { get } + var threatKind: ThreatKind { get } +} + +public extension DataManager { + enum StoredDataType: Hashable, CaseIterable { + case hashPrefixSet(HashPrefixes) + case filterSet(FilterSet) + + enum Kind: CaseIterable { + case hashPrefixSet, filterSet + } + // keep to get a compiler error when number of cases changes + var kind: Kind { + switch self { + case .hashPrefixSet: .hashPrefixSet + case .filterSet: .filterSet + } + } + + var dataKey: any MaliciousSiteDataKeyProtocol { + switch self { + case .hashPrefixSet(let key): key + case .filterSet(let key): key + } + } + + public static var allCases: [DataManager.StoredDataType] { + ThreatKind.allCases.map { threatKind in + Kind.allCases.map { dataKind in + switch dataKind { + case .hashPrefixSet: .hashPrefixSet(.init(threatKind: threatKind)) + case .filterSet: .filterSet(.init(threatKind: threatKind)) + } + } + }.flatMap { $0 } + } + } +} + +public extension DataManager.StoredDataType { + struct HashPrefixes: MaliciousSiteDataKeyProtocol { + public typealias DataSetType = HashPrefixSet + + public let threatKind: ThreatKind + + public var dataType: DataManager.StoredDataType { + .hashPrefixSet(self) + } + } +} +extension MaliciousSiteDataKeyProtocol where Self == DataManager.StoredDataType.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} + +public extension DataManager.StoredDataType { + struct FilterSet: MaliciousSiteDataKeyProtocol { + public typealias DataSetType = FilterDictionary + + public let threatKind: ThreatKind + + public var dataType: DataManager.StoredDataType { + .filterSet(self) + } + } +} +extension MaliciousSiteDataKeyProtocol where Self == DataManager.StoredDataType.FilterSet { + static func filterSet(threatKind: ThreatKind) -> Self { + .init(threatKind: threatKind) + } +} diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift b/Sources/MaliciousSiteProtection/Services/DataManager.swift index 41b2dba9d..0e2154648 100644 --- a/Sources/MaliciousSiteProtection/Services/DataManager.swift +++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift @@ -17,164 +17,89 @@ // import Foundation -import Common import os public protocol DataManaging { - var filterSet: Set { get } - var hashPrefixes: Set { get } - var currentRevision: Int { get } - func saveFilterSet(set: Set) - func saveHashPrefixes(set: Set) - func saveRevision(_ revision: Int) + func dataSet(for key: DataKey) async -> DataKey.DataSetType + func store(_ dataSet: DataKey.DataSetType, for key: DataKey) async } -public final class DataManager: DataManaging { - private lazy var _filterSet: Set = { - loadFilterSet() - }() - - private lazy var _hashPrefixes: Set = { - loadHashPrefix() - }() - - private lazy var _currentRevision: Int = { - loadRevision() - }() - - public private(set) var filterSet: Set { - get { _filterSet } - set { _filterSet = newValue } - } - public private(set) var hashPrefixes: Set { - get { _hashPrefixes } - set { _hashPrefixes = newValue } - } - public private(set) var currentRevision: Int { - get { _currentRevision } - set { _currentRevision = newValue } - } +public actor DataManager: DataManaging { private let embeddedDataProvider: EmbeddedDataProviding private let fileStore: FileStoring - private let encoder = JSONEncoder() - private let revisionFilename = "revision.txt" - private let hashPrefixFilename = "phishingHashPrefixes.json" - private let filterSetFilename = "phishingFilterSet.json" - public init(embeddedDataProvider: EmbeddedDataProviding, fileStore: FileStoring? = nil) { + public typealias FileNameProvider = (DataManager.StoredDataType) -> String + private nonisolated let fileNameProvider: FileNameProvider + + private var store: [StoredDataType: Any] = [:] + + public init(fileStore: FileStoring, embeddedDataProvider: EmbeddedDataProviding, fileNameProvider: @escaping FileNameProvider) { self.embeddedDataProvider = embeddedDataProvider - self.fileStore = fileStore ?? FileStore() + self.fileStore = fileStore + self.fileNameProvider = fileNameProvider } - private func writeHashPrefixes() { - let encoder = JSONEncoder() - do { - let hashPrefixesData = try encoder.encode(Array(hashPrefixes)) - fileStore.write(data: hashPrefixesData, to: hashPrefixFilename) - } catch { - Logger.dataManager.error("Error saving hash prefixes data: \(error.localizedDescription)") + public func dataSet(for key: DataKey) -> DataKey.DataSetType { + let dataType = key.dataType + // return cached dataSet if available + if let data = store[key.dataType] as? DataKey.DataSetType { + return data } - } - private func writeFilterSet() { - let encoder = JSONEncoder() - do { - let filterSetData = try encoder.encode(Array(filterSet)) - fileStore.write(data: filterSetData, to: filterSetFilename) - } catch { - Logger.dataManager.error("Error saving filter set data: \(error.localizedDescription)") - } - } + // read stored dataSet if it‘s newer than the embedded one + let dataSet = readStoredDataSet(for: key) ?? { + // no stored dataSet or the embedded one is newer + let embeddedRevision = embeddedDataProvider.revision(for: dataType) + let embeddedItems = embeddedDataProvider.loadDataSet(for: key) + return .init(revision: embeddedRevision, items: embeddedItems) + }() - private func writeRevision() { - let encoder = JSONEncoder() - do { - let revisionData = try encoder.encode(currentRevision) - fileStore.write(data: revisionData, to: revisionFilename) - } catch { - Logger.dataManager.error("Error saving revision data: \(error.localizedDescription)") - } - } + // cache + store[dataType] = dataSet - private func loadHashPrefix() -> Set { - guard let data = fileStore.read(from: hashPrefixFilename) else { - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } - let decoder = JSONDecoder() - do { - if loadRevisionFromDisk() < embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } - let onDiskHashPrefixes = Set(try decoder.decode(Set.self, from: data)) - return onDiskHashPrefixes - } catch { - Logger.dataManager.error("Error decoding \(self.hashPrefixFilename): \(error.localizedDescription)") - return embeddedDataProvider.loadEmbeddedHashPrefixes() - } + return dataSet } - private func loadFilterSet() -> Set { - guard let data = fileStore.read(from: filterSetFilename) else { - return embeddedDataProvider.loadEmbeddedFilterSet() - } - let decoder = JSONDecoder() + private func readStoredDataSet(for key: DataKey) -> DataKey.DataSetType? { + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + guard let data = fileStore.read(from: fileName) else { return nil } + + let storedDataSet: DataKey.DataSetType do { - if loadRevisionFromDisk() < embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.loadEmbeddedFilterSet() - } - let onDiskFilterSet = Set(try decoder.decode(Set.self, from: data)) - return onDiskFilterSet + storedDataSet = try JSONDecoder().decode(DataKey.DataSetType.self, from: data) } catch { - Logger.dataManager.error("Error decoding \(self.filterSetFilename): \(error.localizedDescription)") - return embeddedDataProvider.loadEmbeddedFilterSet() + Logger.dataManager.error("Error decoding \(fileName): \(error.localizedDescription)") + return nil } - } - private func loadRevisionFromDisk() -> Int { - guard let data = fileStore.read(from: revisionFilename) else { - return embeddedDataProvider.embeddedRevision - } - let decoder = JSONDecoder() - do { - return try decoder.decode(Int.self, from: data) - } catch { - Logger.dataManager.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return embeddedDataProvider.embeddedRevision + // compare to the embedded data revision + let embeddedDataRevision = embeddedDataProvider.revision(for: dataType) + guard storedDataSet.revision >= embeddedDataRevision else { + Logger.dataManager.error("Stored \(fileName) is outdated: revision: \(storedDataSet.revision), embedded revision: \(embeddedDataRevision).") + return nil } + + return storedDataSet } - private func loadRevision() -> Int { - guard let data = fileStore.read(from: revisionFilename) else { - return embeddedDataProvider.embeddedRevision - } - let decoder = JSONDecoder() + public func store(_ dataSet: DataKey.DataSetType, for key: DataKey) { + let dataType = key.dataType + let fileName = fileNameProvider(dataType) + self.store[dataType] = dataSet + + let data: Data do { - let loadedRevision = try decoder.decode(Int.self, from: data) - if loadedRevision < embeddedDataProvider.embeddedRevision { - return embeddedDataProvider.embeddedRevision - } - return loadedRevision + data = try JSONEncoder().encode(dataSet) } catch { - Logger.dataManager.error("Error decoding \(self.revisionFilename): \(error.localizedDescription)") - return embeddedDataProvider.embeddedRevision + Logger.dataManager.error("Error encoding \(fileName): \(error.localizedDescription)") + assertionFailure("Failed to store data to \(fileName): \(error)") + return } - } -} - -extension DataManager { - public func saveFilterSet(set: Set) { - self.filterSet = set - writeFilterSet() - } - public func saveHashPrefixes(set: Set) { - self.hashPrefixes = set - writeHashPrefixes() + let success = fileStore.write(data: data, to: fileName) + assert(success) } - public func saveRevision(_ revision: Int) { - self.currentRevision = revision - writeRevision() - } } diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift index 5ca4d9f7c..084199c8b 100644 --- a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -18,56 +18,35 @@ import Foundation import CryptoKit -import Common -import os public protocol EmbeddedDataProviding { - var embeddedRevision: Int { get } - func loadEmbeddedFilterSet() -> Set - func loadEmbeddedHashPrefixes() -> Set -} - -public struct EmbeddedDataProvider: EmbeddedDataProviding { - public let embeddedRevision: Int - private let embeddedFilterSetURL: URL - private let embeddedFilterSetDataSHA: String - private let embeddedHashPrefixURL: URL - private let embeddedHashPrefixDataSHA: String + func revision(for dataType: DataManager.StoredDataType) -> Int + func url(for dataType: DataManager.StoredDataType) -> URL + func hash(for dataType: DataManager.StoredDataType) -> String - public init(revision: Int, filterSetURL: URL, filterSetDataSHA: String, hashPrefixURL: URL, hashPrefixDataSHA: String) { - embeddedFilterSetURL = filterSetURL - embeddedFilterSetDataSHA = filterSetDataSHA - embeddedHashPrefixURL = hashPrefixURL - embeddedHashPrefixDataSHA = hashPrefixDataSHA - embeddedRevision = revision - } + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType +} - private func loadData(from url: URL, expectedSHA: String) throws -> Data { - let data = try Data(contentsOf: url) - let sha256 = SHA256.hash(data: data) - let hashString = sha256.compactMap { String(format: "%02x", $0) }.joined() +extension EmbeddedDataProviding { - guard hashString == expectedSHA else { - throw NSError(domain: "PhishingDetectionDataProvider", code: 1001, userInfo: [NSLocalizedDescriptionKey: "SHA mismatch"]) + public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType { + let dataType = key.dataType + let url = url(for: dataType) + let data: Data + do { + data = try Data(contentsOf: url) +#if DEBUG + assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)") +#endif + } catch { + fatalError("Could not load embedded data set at \(url.path): \(error)") } - return data - } - - public func loadEmbeddedFilterSet() -> Set { - do { - let filterSetData = try loadData(from: embeddedFilterSetURL, expectedSHA: embeddedFilterSetDataSHA) - return try JSONDecoder().decode(Set.self, from: filterSetData) - } catch { - fatalError("🔴 Error: SHA mismatch for filterSet JSON file. Expected \(self.embeddedFilterSetDataSHA)") - } - } - - public func loadEmbeddedHashPrefixes() -> Set { do { - let hashPrefixData = try loadData(from: embeddedHashPrefixURL, expectedSHA: embeddedHashPrefixDataSHA) - return try JSONDecoder().decode(Set.self, from: hashPrefixData) + let result = try JSONDecoder().decode(DataKey.EmbeddedDataSetType.self, from: data) + return result } catch { - fatalError("🔴 Error: SHA mismatch for hashPrefixes JSON file. Expected \(self.embeddedHashPrefixDataSHA)") + fatalError("Could not decode embedded data set at \(url.path): \(error)") } } + } diff --git a/Sources/MaliciousSiteProtection/Services/FileStore.swift b/Sources/MaliciousSiteProtection/Services/FileStore.swift index e0714401a..6cc03b7e7 100644 --- a/Sources/MaliciousSiteProtection/Services/FileStore.swift +++ b/Sources/MaliciousSiteProtection/Services/FileStore.swift @@ -20,22 +20,15 @@ import Foundation import os public protocol FileStoring { - func write(data: Data, to filename: String) + @discardableResult func write(data: Data, to filename: String) -> Bool func read(from filename: String) -> Data? } -public struct FileStore: FileStoring { +struct FileStore: FileStoring, CustomDebugStringConvertible { private let dataStoreURL: URL - public init() { - let dataStoreDirectory: URL - do { - dataStoreDirectory = try FileManager.default.url(for: .applicationSupportDirectory, in: .userDomainMask, appropriateFor: nil, create: true) - } catch { - Logger.dataManager.error("Error accessing application support directory: \(error.localizedDescription)") - dataStoreDirectory = FileManager.default.temporaryDirectory - } - dataStoreURL = dataStoreDirectory.appendingPathComponent(Bundle.main.bundleIdentifier!, isDirectory: true) + init(dataStoreURL: URL) { + self.dataStoreURL = dataStoreURL createDirectoryIfNeeded() } @@ -47,12 +40,14 @@ public struct FileStore: FileStoring { } } - public func write(data: Data, to filename: String) { + public func write(data: Data, to filename: String) -> Bool { let fileURL = dataStoreURL.appendingPathComponent(filename) do { try data.write(to: fileURL) + return true } catch { Logger.dataManager.error("Error writing to directory: \(error.localizedDescription)") + return false } } @@ -65,4 +60,8 @@ public struct FileStore: FileStoring { return nil } } + + var debugDescription: String { + return "<\(type(of: self)) - \"\(dataStoreURL.path)\">" + } } diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index af9f60e7d..0dd7d1106 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -21,6 +21,8 @@ import Common import os public protocol UpdateManaging { + func updateData(for key: some MaliciousSiteDataKeyProtocol) async + func updateFilterSet() async func updateHashPrefixes() async } @@ -34,62 +36,39 @@ public struct UpdateManager: UpdateManaging { self.dataManager = dataManager } - private func updateSet( - currentSet: Set, - insert: [T], - delete: [T], - replace: Bool, - saveSet: (Set) -> Void - ) { - var newSet = currentSet - - if replace { - newSet = Set(insert) - } else { - newSet.formUnion(insert) - newSet.subtract(delete) - } - - saveSet(newSet) - } + public func updateData(for key: DataKey) async { + // load currently stored data set + var dataSet = await dataManager.dataSet(for: key) + let oldRevision = dataSet.revision - public func updateFilterSet() async { - let changeSet: APIClient.Response.FiltersChangeSet + // get change set from current revision from API + let changeSet: APIClient.ChangeSetResponse do { - changeSet = try await apiClient.filtersChangeSet(for: .phishing, revision: dataManager.currentRevision) + let request = DataKey.DataSetType.APIRequestType(threatKind: key.threatKind, revision: oldRevision) + changeSet = try await apiClient.load(request) } catch { Logger.updateManager.error("error fetching filter set: \(error)") return } - updateSet( - currentSet: dataManager.filterSet, - insert: changeSet.insert, - delete: changeSet.delete, - replace: changeSet.replace - ) { newSet in - self.dataManager.saveFilterSet(set: newSet) + guard !changeSet.isEmpty || changeSet.revision != dataSet.revision else { + Logger.updateManager.debug("no changes to filter set") + return } - dataManager.saveRevision(changeSet.revision) - Logger.updateManager.debug("filterSet updated to revision \(self.dataManager.currentRevision)") + + // apply changes + dataSet.apply(changeSet) + + // store back + await self.dataManager.store(dataSet, for: key) + Logger.updateManager.debug("\(type(of: key)).\(key.threatKind) updated from rev.\(oldRevision) to rev.\(dataSet.revision)") + } + + public func updateFilterSet() async { + await updateData(for: .filterSet(threatKind: .phishing)) } public func updateHashPrefixes() async { - let changeSet: APIClient.Response.HashPrefixesChangeSet - do { - changeSet = try await apiClient.hashPrefixesChangeSet(for: .phishing, revision: dataManager.currentRevision) - } catch { - Logger.updateManager.error("error fetching hash prefixes: \(error)") - return - } - updateSet( - currentSet: dataManager.hashPrefixes, - insert: changeSet.insert, - delete: changeSet.delete, - replace: changeSet.replace - ) { newSet in - self.dataManager.saveHashPrefixes(set: newSet) - } - dataManager.saveRevision(changeSet.revision) - Logger.updateManager.debug("hashPrefixes updated to revision \(self.dataManager.currentRevision)") + await updateData(for: .hashPrefixes(threatKind: .phishing)) } + } diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift index dd633be54..f6b0de23a 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteDetectorTests.swift @@ -27,26 +27,24 @@ class MaliciousSiteDetectorTests: XCTestCase { private var mockEventMapping: MockEventMapping! private var detector: MaliciousSiteDetector! - override func setUp() { - super.setUp() + override func setUp() async throws { mockAPIClient = MockMaliciousSiteProtectionAPIClient() mockDataManager = MockMaliciousSiteProtectionDataManager() mockEventMapping = MockEventMapping() detector = MaliciousSiteDetector(apiClient: mockAPIClient, dataManager: mockDataManager, eventMapping: mockEventMapping) } - override func tearDown() { + override func tearDown() async throws { mockAPIClient = nil mockDataManager = nil mockEventMapping = nil detector = nil - super.tearDown() } func testIsMaliciousWithLocalFilterHit() async { let filter = Filter(hash: "255a8a793097aeea1f06a19c08cde28db0eb34c660c6e4e7480c9525d034b16d", regex: ".*malicious.*") - mockDataManager.filterSet = Set([filter]) - mockDataManager.hashPrefixes = Set(["255a8a79"]) + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["255a8a79"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://malicious.com/")! @@ -56,8 +54,8 @@ class MaliciousSiteDetectorTests: XCTestCase { } func testIsMaliciousWithApiMatch() async { - mockDataManager.filterSet = Set() - mockDataManager.hashPrefixes = ["a379a6f6"] + await mockDataManager.store(FilterDictionary(revision: 0, items: []), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["a379a6f6"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://example.com/mal")! @@ -68,8 +66,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithHashPrefixMatch() async { let filter = Filter(hash: "notamatch", regex: ".*malicious.*") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["4c64eb24"] // matches safe.com + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24" /* matches safe.com */]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! @@ -81,8 +79,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithFullHashMatch() async { // 4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b let filter = Filter(hash: "4c64eb2468bcd3e113b37167e6b819aeccf550f974a6082ef17fb74ca68e823b", regex: "https://safe.com/maliciousURI") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["4c64eb24"] + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["4c64eb24"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! @@ -93,8 +91,8 @@ class MaliciousSiteDetectorTests: XCTestCase { func testIsMaliciousWithNoHashPrefixMatch() async { let filter = Filter(hash: "testHash", regex: ".*malicious.*") - mockDataManager.filterSet = [filter] - mockDataManager.hashPrefixes = ["testPrefix"] + await mockDataManager.store(FilterDictionary(revision: 0, items: [filter]), for: .filterSet(threatKind: .phishing)) + await mockDataManager.store(HashPrefixSet(revision: 0, items: ["testPrefix"]), for: .hashPrefixes(threatKind: .phishing)) let url = URL(string: "https://safe.com")! diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift index a6763c2f3..91e3c228e 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift @@ -27,21 +27,28 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { static let hashPrefixesFileName = "phishingHashPrefixes.json" static let filterSetFileName = "phishingFilterSet.json" } - let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName, "revision.txt"] + let datasetFiles: [String] = [Constants.hashPrefixesFileName, Constants.filterSetFileName] var dataManager: MaliciousSiteProtection.DataManager! var fileStore: MaliciousSiteProtection.FileStoring! - override func setUp() { - super.setUp() + override func setUp() async throws { embeddedDataProvider = MockMaliciousSiteProtectionEmbeddedDataProvider() fileStore = MockMaliciousSiteProtectionFileStore() - dataManager = MaliciousSiteProtection.DataManager(embeddedDataProvider: embeddedDataProvider, fileStore: fileStore) + setUpDataManager() } - override func tearDown() { + func setUpDataManager() { + dataManager = MaliciousSiteProtection.DataManager(fileStore: fileStore, embeddedDataProvider: embeddedDataProvider, fileNameProvider: { dataType in + switch dataType { + case .filterSet: Constants.filterSetFileName + case .hashPrefixSet: Constants.hashPrefixesFileName + } + }) + } + + override func tearDown() async throws { embeddedDataProvider = nil dataManager = nil - super.tearDown() } func clearDatasets() { @@ -54,21 +61,21 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { func testWhenNoDataSavedThenProviderDataReturned() async { clearDatasets() let expectedFilerSet = Set([Filter(hash: "some", regex: "some")]) + let expectedFilerDict = FilterDictionary(revision: 65, items: expectedFilerSet) let expectedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: expectedFilerSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: expectedHashPrefix) + embeddedDataProvider.filterSet = expectedFilerSet + embeddedDataProvider.hashPrefixes = expectedHashPrefix - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(actualFilterSet, expectedFilerSet) - XCTAssertEqual(actualHashPrefix, expectedHashPrefix) + XCTAssertEqual(actualFilterSet, expectedFilerDict) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefix) } func testWhenEmbeddedRevisionNewerThanOnDisk_ThenLoadEmbedded() async { let encoder = JSONEncoder() // On Disk Data Setup - fileStore.write(data: "1".utf8data, to: "revision.txt") let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")]) let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) let onDiskHashPrefix = Set(["faffa"]) @@ -79,27 +86,28 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { // Embedded Data Setup embeddedDataProvider.embeddedRevision = 5 let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 5, items: embeddedFilterSet) let embeddedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) - - let actualRevision = dataManager.currentRevision - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes - - XCTAssertEqual(actualFilterSet, embeddedFilterSet) - XCTAssertEqual(actualHashPrefix, embeddedHashPrefix) - XCTAssertEqual(actualRevision, 5) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 5) + XCTAssertEqual(actualHashPrefixRevision, 5) } func testWhenEmbeddedRevisionOlderThanOnDisk_ThenDontLoadEmbedded() async { - let encoder = JSONEncoder() // On Disk Data Setup - fileStore.write(data: "6".utf8data, to: "revision.txt") - let onDiskFilterSet = Set([Filter(hash: "other", regex: "other")]) - let filterSetData = try! encoder.encode(Array(onDiskFilterSet)) - let onDiskHashPrefix = Set(["faffa"]) - let hashPrefixData = try! encoder.encode(Array(onDiskHashPrefix)) + let onDiskFilterDict = FilterDictionary(revision: 6, items: [Filter(hash: "other", regex: "other")]) + let filterSetData = try! JSONEncoder().encode(onDiskFilterDict) + let onDiskHashPrefix = HashPrefixSet(revision: 6, items: ["faffa"]) + let hashPrefixData = try! JSONEncoder().encode(onDiskHashPrefix) fileStore.write(data: filterSetData, to: Constants.filterSetFileName) fileStore.write(data: hashPrefixData, to: Constants.hashPrefixesFileName) @@ -107,16 +115,42 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { embeddedDataProvider.embeddedRevision = 1 let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) let embeddedHashPrefix = Set(["sassa"]) - embeddedDataProvider.shouldReturnFilterSet(set: embeddedFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: embeddedHashPrefix) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix - let actualRevision = dataManager.currentRevision - let actualFilterSet = dataManager.filterSet - let actualHashPrefix = dataManager.hashPrefixes + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision - XCTAssertEqual(actualFilterSet, onDiskFilterSet) + XCTAssertEqual(actualFilterSet, onDiskFilterDict) XCTAssertEqual(actualHashPrefix, onDiskHashPrefix) - XCTAssertEqual(actualRevision, 6) + XCTAssertEqual(actualFilterSetRevision, 6) + XCTAssertEqual(actualHashPrefixRevision, 6) + } + + func testWhenStoredDataIsMalformed_ThenEmbeddedDataIsLoaded() async { + // On Disk Data Setup + fileStore.write(data: "fake".utf8data, to: Constants.filterSetFileName) + fileStore.write(data: "fake".utf8data, to: Constants.hashPrefixesFileName) + + // Embedded Data Setup + embeddedDataProvider.embeddedRevision = 1 + let embeddedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let embeddedFilterDict = FilterDictionary(revision: 1, items: embeddedFilterSet) + let embeddedHashPrefix = Set(["sassa"]) + embeddedDataProvider.filterSet = embeddedFilterSet + embeddedDataProvider.hashPrefixes = embeddedHashPrefix + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, embeddedFilterDict) + XCTAssertEqual(actualHashPrefix.set, embeddedHashPrefix) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) } func testWriteAndLoadData() async { @@ -125,31 +159,31 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { let expectedFilterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) let expectedRevision = 65 - dataManager.saveHashPrefixes(set: expectedHashPrefixes) - dataManager.saveFilterSet(set: expectedFilterSet) - dataManager.saveRevision(expectedRevision) - - XCTAssertEqual(dataManager.filterSet, expectedFilterSet) - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes) - XCTAssertEqual(dataManager.currentRevision, expectedRevision) - - // Test decode JSON data to expected types - let storedHashPrefixesData = fileStore.read(from: Constants.hashPrefixesFileName) - let storedFilterSetData = fileStore.read(from: Constants.filterSetFileName) - let storedRevisionData = fileStore.read(from: "revision.txt") - - let decoder = JSONDecoder() - if let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!), - let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!), - let storedRevisionString = String(data: storedRevisionData!, encoding: .utf8), - let storedRevision = Int(storedRevisionString.trimmingCharacters(in: .whitespacesAndNewlines)) { - - XCTAssertEqual(storedFilterSet, expectedFilterSet) - XCTAssertEqual(storedHashPrefixes, expectedHashPrefixes) - XCTAssertEqual(storedRevision, expectedRevision) - } else { - XCTFail("Failed to decode stored PhishingDetection data") - } + await dataManager.store(HashPrefixSet(revision: expectedRevision, items: expectedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: expectedRevision, items: expectedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 65) + XCTAssertEqual(actualHashPrefixRevision, 65) + + // Test reloading data + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: expectedRevision, items: expectedFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, expectedHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 65) + XCTAssertEqual(reloadedHashPrefixRevision, 65) } func testLazyLoadingDoesNotReturnStaleData() async { @@ -158,53 +192,56 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { // Set up initial data let initialFilterSet = Set([Filter(hash: "initial", regex: "initial")]) let initialHashPrefixes = Set(["initialPrefix"]) - embeddedDataProvider.shouldReturnFilterSet(set: initialFilterSet) - embeddedDataProvider.shouldReturnHashPrefixes(set: initialHashPrefixes) + embeddedDataProvider.filterSet = initialFilterSet + embeddedDataProvider.hashPrefixes = initialHashPrefixes // Access the lazy-loaded properties to trigger loading - let loadedFilterSet = dataManager.filterSet - let loadedHashPrefixes = dataManager.hashPrefixes + let loadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let loadedHashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) // Validate loaded data matches initial data - XCTAssertEqual(loadedFilterSet, initialFilterSet) - XCTAssertEqual(loadedHashPrefixes, initialHashPrefixes) + XCTAssertEqual(loadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(loadedHashPrefixes.set, initialHashPrefixes) // Update in-memory data let updatedFilterSet = Set([Filter(hash: "updated", regex: "updated")]) let updatedHashPrefixes = Set(["updatedPrefix"]) - dataManager.saveFilterSet(set: updatedFilterSet) - dataManager.saveHashPrefixes(set: updatedHashPrefixes) - - // Access lazy-loaded properties again - let reloadedFilterSet = dataManager.filterSet - let reloadedHashPrefixes = dataManager.hashPrefixes - - // Validate reloaded data matches updated data - XCTAssertEqual(reloadedFilterSet, updatedFilterSet) - XCTAssertEqual(reloadedHashPrefixes, updatedHashPrefixes) - - // Validate on-disk data is also updated - let storedFilterSetData = fileStore.read(from: Constants.filterSetFileName) - let storedHashPrefixesData = fileStore.read(from: Constants.hashPrefixesFileName) - - let decoder = JSONDecoder() - if let storedFilterSet = try? decoder.decode(Set.self, from: storedFilterSetData!), - let storedHashPrefixes = try? decoder.decode(Set.self, from: storedHashPrefixesData!) { - - XCTAssertEqual(storedFilterSet, updatedFilterSet) - XCTAssertEqual(storedHashPrefixes, updatedHashPrefixes) - } else { - XCTFail("Failed to decode stored PhishingDetection data after update") - } + await dataManager.store(HashPrefixSet(revision: 1, items: updatedHashPrefixes), for: .hashPrefixes(threatKind: .phishing)) + await dataManager.store(FilterDictionary(revision: 1, items: updatedFilterSet), for: .filterSet(threatKind: .phishing)) + + let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let actualFilterSetRevision = actualFilterSet.revision + let actualHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(actualFilterSet, FilterDictionary(revision: 1, items: updatedFilterSet)) + XCTAssertEqual(actualHashPrefix.set, updatedHashPrefixes) + XCTAssertEqual(actualFilterSetRevision, 1) + XCTAssertEqual(actualHashPrefixRevision, 1) + + // Test reloading data – embedded data should be returned as its revision is greater + setUpDataManager() + + let reloadedFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + let reloadedHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let reloadedFilterSetRevision = actualFilterSet.revision + let reloadedHashPrefixRevision = actualFilterSet.revision + + XCTAssertEqual(reloadedFilterSet, FilterDictionary(revision: 65, items: initialFilterSet)) + XCTAssertEqual(reloadedHashPrefix.set, initialHashPrefixes) + XCTAssertEqual(reloadedFilterSetRevision, 1) + XCTAssertEqual(reloadedHashPrefixRevision, 1) } } class MockMaliciousSiteProtectionFileStore: MaliciousSiteProtection.FileStoring { + private var data: [String: Data] = [:] - func write(data: Data, to filename: String) { + func write(data: Data, to filename: String) -> Bool { self.data[filename] = data + return true } func read(from filename: String) -> Data? { diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift index 6246e1f93..600352f52 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift @@ -21,28 +21,42 @@ import XCTest @testable import MaliciousSiteProtection class MaliciousSiteProtectionEmbeddedDataProviderTest: XCTestCase { - var filterSetURL: URL! - var hashPrefixURL: URL! - var dataProvider: MaliciousSiteProtection.EmbeddedDataProvider! - override func setUp() { - super.setUp() - filterSetURL = Bundle.module.url(forResource: "phishingFilterSet", withExtension: "json")! - hashPrefixURL = Bundle.module.url(forResource: "phishingHashPrefixes", withExtension: "json")! - } + struct TestEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + func revision(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + 0 + } - override func tearDown() { - filterSetURL = nil - hashPrefixURL = nil - dataProvider = nil - super.tearDown() + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { + case .filterSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)FilterSet", withExtension: "json")! + case .hashPrefixSet(let key): + Bundle.module.url(forResource: "\(key.threatKind)HashPrefixes", withExtension: "json")! + } + } + + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + switch dataType { + case .filterSet(let key): + switch key.threatKind { + case .phishing: + "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504" + } + case .hashPrefixSet(let key): + switch key.threatKind { + case .phishing: + "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f" + } + } + } } func testDataProviderLoadsJSON() { - dataProvider = .init(revision: 0, filterSetURL: filterSetURL, filterSetDataSHA: "4fd2868a4f264501ec175ab866504a2a96c8d21a3b5195b405a4a83b51eae504", hashPrefixURL: hashPrefixURL, hashPrefixDataSHA: "21b047a9950fcaf86034a6b16181e18815cb8d276386d85c8977ca8c5f8aa05f") + let dataProvider = TestEmbeddedDataProvider() let expectedFilter = Filter(hash: "e4753ddad954dafd4ff4ef67f82b3c1a2db6ef4a51bda43513260170e558bd13", regex: "(?i)^https?\\:\\/\\/privacy-test-pages\\.site(?:\\:(?:80|443))?\\/security\\/badware\\/phishing\\.html$") - XCTAssertTrue(dataProvider.loadEmbeddedFilterSet().contains(expectedFilter)) - XCTAssertTrue(dataProvider.loadEmbeddedHashPrefixes().contains("012db806")) + XCTAssertTrue(dataProvider.loadDataSet(for: .filterSet(threatKind: .phishing)).contains(expectedFilter)) + XCTAssertTrue(dataProvider.loadDataSet(for: .hashPrefixes(threatKind: .phishing)).contains("012db806")) } } diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift index 53d68dbd5..5eabf6390 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift @@ -27,40 +27,36 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { var apiClient: MaliciousSiteProtection.APIClientProtocol! override func setUp() async throws { - try await super.setUp() apiClient = MockMaliciousSiteProtectionAPIClient() dataManager = MockMaliciousSiteProtectionDataManager() updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager) - dataManager.saveRevision(0) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() } override func tearDown() { updateManager = nil dataManager = nil apiClient = nil - super.tearDown() } func testUpdateHashPrefixes() async { await updateManager.updateHashPrefixes() - XCTAssertFalse(dataManager.hashPrefixes.isEmpty, "Hash prefixes should not be empty after update.") - XCTAssertEqual(dataManager.hashPrefixes, [ + let dataSet = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + XCTAssertEqual(dataSet, HashPrefixSet(revision: 1, items: [ "aa00bb11", "bb00cc11", "cc00dd11", "dd00ee11", "a379a6f6" - ]) + ])) } func testUpdateFilterSet() async { await updateManager.updateFilterSet() - XCTAssertEqual(dataManager.filterSet, [ + let dataSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(dataSet, FilterDictionary(revision: 1, items: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") - ]) + ])) } func testRevision1AddsAndDeletesData() async { @@ -75,19 +71,27 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { "93e2435e" ] - // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(1) + // revision 0 -> 1 + await updateManager.updateFilterSet() + await updateManager.updateHashPrefixes() + + // revision 1 -> 2 await updateManager.updateFilterSet() await updateManager.updateHashPrefixes() - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 2, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 2, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision2AddsAndDeletesData() async { let expectedFilterSet: Set = [ Filter(hash: "testhash4", regex: ".*test.*"), - Filter(hash: "testhash1", regex: ".*example.*") + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), ] let expectedHashPrefixes: Set = [ "aa00bb11", @@ -98,46 +102,99 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { ] // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(2) + await dataManager.store(FilterDictionary(revision: 2, items: [ + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash2", regex: ".*test1.*"), + Filter(hash: "testhash3", regex: ".*test3.*"), + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 2, items: [ + "aa00bb11", + "bb00cc11", + "cc00dd11", + "dd00ee11", + "a379a6f6" + ]), for: .hashPrefixes(threatKind: .phishing)) + await updateManager.updateFilterSet() await updateManager.updateHashPrefixes() - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision3AddsAndDeletesNothing() async { - let expectedFilterSet = dataManager.filterSet - let expectedHashPrefixes = dataManager.hashPrefixes + let expectedFilterSet: Set = [] + let expectedHashPrefixes: Set = [] // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(3) + await dataManager.store(FilterDictionary(revision: 3, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 3, items: []), for: .hashPrefixes(threatKind: .phishing)) + await updateManager.updateFilterSet() await updateManager.updateHashPrefixes() - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 3, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 3, items: expectedFilterSet), "Filter set should match the expected set after update.") } func testRevision4AddsAndDeletesData() async { let expectedFilterSet: Set = [ - Filter(hash: "testhash2", regex: ".*test.*"), - Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash5", regex: ".*test.*") ] let expectedHashPrefixes: Set = [ "a379a6f6", + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 4, items: []), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 4, items: []), for: .hashPrefixes(threatKind: .phishing)) + + await updateManager.updateFilterSet() + await updateManager.updateHashPrefixes() + + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 5, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 5, items: expectedFilterSet), "Filter set should match the expected set after update.") + } + + func testRevision5replacesData() async { + let expectedFilterSet: Set = [ + Filter(hash: "testhash6", regex: ".*test6.*") + ] + let expectedHashPrefixes: Set = [ + "aa55aa55" + ] + + // Save revision and update the filter set and hash prefixes + await dataManager.store(FilterDictionary(revision: 5, items: [ + Filter(hash: "testhash2", regex: ".*test.*"), + Filter(hash: "testhash1", regex: ".*example.*"), + Filter(hash: "testhash5", regex: ".*test.*") + ]), for: .filterSet(threatKind: .phishing)) + await dataManager.store(HashPrefixSet(revision: 5, items: [ + "a379a6f6", "dd00ee11", "cc00dd11", "bb00cc11" - ] + ]), for: .hashPrefixes(threatKind: .phishing)) - // Save revision and update the filter set and hash prefixes - dataManager.saveRevision(4) await updateManager.updateFilterSet() await updateManager.updateHashPrefixes() - XCTAssertEqual(dataManager.filterSet, expectedFilterSet, "Filter set should match the expected set after update.") - XCTAssertEqual(dataManager.hashPrefixes, expectedHashPrefixes, "Hash prefixes should match the expected set after update.") + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + + XCTAssertEqual(hashPrefixes, HashPrefixSet(revision: 6, items: expectedHashPrefixes), "Hash prefixes should match the expected set after update.") + XCTAssertEqual(filterSet, FilterDictionary(revision: 6, items: expectedFilterSet), "Filter set should match the expected set after update.") } + } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index 24d1c203b..874e4fdbd 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -27,22 +27,27 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl 0: .init(insert: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") - ], delete: [], revision: 0, replace: true), + ], delete: [], revision: 1, replace: false), 1: .init(insert: [ Filter(hash: "testhash3", regex: ".*test.*") ], delete: [ Filter(hash: "testhash1", regex: ".*example.*"), - ], revision: 1, replace: false), + ], revision: 2, replace: false), 2: .init(insert: [ Filter(hash: "testhash4", regex: ".*test.*") ], delete: [ Filter(hash: "testhash2", regex: ".*test.*"), - ], revision: 2, replace: false), + ], revision: 3, replace: false), 4: .init(insert: [ Filter(hash: "testhash5", regex: ".*test.*") ], delete: [ Filter(hash: "testhash3", regex: ".*test.*"), - ], revision: 4, replace: false) + ], revision: 5, replace: false), + 5: .init(insert: [ + Filter(hash: "testhash6", regex: ".*test6.*") + ], delete: [ + Filter(hash: "testhash3", regex: ".*test.*"), + ], revision: 6, replace: true), ] private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [ @@ -52,17 +57,20 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl "cc00dd11", "dd00ee11", "a379a6f6" - ], delete: [], revision: 0, replace: true), + ], delete: [], revision: 1, replace: false), 1: .init(insert: ["93e2435e"], delete: [ "cc00dd11", "dd00ee11", - ], revision: 1, replace: false), + ], revision: 2, replace: false), 2: .init(insert: ["c0be0d0a6"], delete: [ "bb00cc11", - ], revision: 2, replace: false), + ], revision: 3, replace: false), 4: .init(insert: ["a379a6f6"], delete: [ "aa00bb11", - ], revision: 4, replace: false) + ], revision: 5, replace: false), + 5: .init(insert: ["aa55aa55"], delete: [ + "ffgghhzz", + ], revision: 6, replace: true), ] public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request: APIRequestProtocol { @@ -87,7 +95,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches { .init(matches: [ - Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), + Match(hostname: "example.com", url: "https://example.com/mal", regex: ".*", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil) ]) } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index 64d82bbea..963445abe 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -20,25 +20,14 @@ import Foundation import MaliciousSiteProtection public class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { - public var filterSet: Set - public var hashPrefixes: Set - public var currentRevision: Int + var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() - public init() { - filterSet = Set() - hashPrefixes = Set() - currentRevision = 0 + public func dataSet(for key: DataKey) async -> DataKey.DataSetType where DataKey : MaliciousSiteProtection.MaliciousSiteDataKeyProtocol { + store[key.dataType] as? DataKey.DataSetType ?? .init(revision: 0, items: []) } - public func saveFilterSet(set: Set) { - filterSet = set + public func store(_ dataSet: DataKey.DataSetType, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKeyProtocol { + store[key.dataType] = dataSet } - public func saveHashPrefixes(set: Set) { - hashPrefixes = set - } - - public func saveRevision(_ revision: Int) { - currentRevision = revision - } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 9bb44bbe2..2cf9d25da 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -23,25 +23,30 @@ public class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProte public var embeddedRevision: Int = 65 var loadHashPrefixesCalled: Bool = false var loadFilterSetCalled: Bool = true - var hashPrefixes: Set = ["aabb"] - var filterSet: Set = [Filter(hash: "dummyhash", regex: "dummyregex")] + var hashPrefixes = Set(["aabb"]) + var filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) - public func shouldReturnFilterSet(set: Set) { - self.filterSet = set + public func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + embeddedRevision } - - public func shouldReturnHashPrefixes(set: Set) { - self.hashPrefixes = set + + public func url(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + URL.empty } - - public func loadEmbeddedFilterSet() -> Set { - self.loadHashPrefixesCalled = true - return self.filterSet + + public func hash(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + "" } - public func loadEmbeddedHashPrefixes() -> Set { - self.loadFilterSetCalled = true - return self.hashPrefixes + public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType where DataKey : MaliciousSiteDataKeyProtocol { + switch key.dataType { + case .filterSet: + self.loadFilterSetCalled = true + return Array(filterSet) as! DataKey.EmbeddedDataSetType + case .hashPrefixSet: + self.loadHashPrefixesCalled = true + return Array(hashPrefixes) as! DataKey.EmbeddedDataSetType + } } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index 3eb67c06b..c7fac96b4 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -20,10 +20,18 @@ import Foundation import MaliciousSiteProtection public class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { + var didUpdateFilterSet = false var didUpdateHashPrefixes = false var completionHandler: (() -> Void)? + public func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKeyProtocol) async { + switch key.dataType { + case .filterSet: await updateFilterSet() + case .hashPrefixSet: await updateHashPrefixes() + } + } + public func updateFilterSet() async { didUpdateFilterSet = true checkCompletion() From 1509f0cf7c48800bf79bc5f6a5dfa3f7f3f7b696 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Tue, 26 Nov 2024 19:42:10 +0600 Subject: [PATCH 05/15] add missing public and var --- Sources/MaliciousSiteProtection/Model/StoredData.swift | 7 +++++++ Sources/MaliciousSiteProtection/Services/FileStore.swift | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift index 041fa7390..7164bdca7 100644 --- a/Sources/MaliciousSiteProtection/Model/StoredData.swift +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -49,6 +49,13 @@ public extension DataManager { } } + public var threatKind: ThreatKind { + switch self { + case .hashPrefixSet(let key): key.threatKind + case .filterSet(let key): key.threatKind + } + } + public static var allCases: [DataManager.StoredDataType] { ThreatKind.allCases.map { threatKind in Kind.allCases.map { dataKind in diff --git a/Sources/MaliciousSiteProtection/Services/FileStore.swift b/Sources/MaliciousSiteProtection/Services/FileStore.swift index 6cc03b7e7..06418e6a2 100644 --- a/Sources/MaliciousSiteProtection/Services/FileStore.swift +++ b/Sources/MaliciousSiteProtection/Services/FileStore.swift @@ -24,10 +24,10 @@ public protocol FileStoring { func read(from filename: String) -> Data? } -struct FileStore: FileStoring, CustomDebugStringConvertible { +public struct FileStore: FileStoring, CustomDebugStringConvertible { private let dataStoreURL: URL - init(dataStoreURL: URL) { + public init(dataStoreURL: URL) { self.dataStoreURL = dataStoreURL createDirectoryIfNeeded() } @@ -61,7 +61,7 @@ struct FileStore: FileStoring, CustomDebugStringConvertible { } } - var debugDescription: String { + public var debugDescription: String { return "<\(type(of: self)) - \"\(dataStoreURL.path)\">" } } From 0f3e7ba972a5b4cc580746ecdb339f0bc8d1933f Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Wed, 27 Nov 2024 18:28:40 +0600 Subject: [PATCH 06/15] fix API hash mathing --- Sources/Common/Extensions/HashExtension.swift | 9 ++++++-- .../MaliciousSiteDetector.swift | 23 ++++++++----------- .../Extensions/StringExtensionTests.swift | 11 +++++++++ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/Sources/Common/Extensions/HashExtension.swift b/Sources/Common/Extensions/HashExtension.swift index b6752cf57..13095cf63 100644 --- a/Sources/Common/Extensions/HashExtension.swift +++ b/Sources/Common/Extensions/HashExtension.swift @@ -42,8 +42,13 @@ extension Data { extension String { public var sha1: String { - let dataBytes = data(using: .utf8)! - return dataBytes.sha1 + let result = utf8data.sha1 + return result + } + + public var sha256: String { + let result = utf8data.sha256 + return result } } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 206524b9a..82704fc34 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -46,23 +46,17 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { self.eventMapping = eventMapping } - private func generateHashPrefix(for canonicalHost: String, length: Int) -> String { - let hostnameHash = SHA256.hash(data: Data(canonicalHost.utf8)).map { String(format: "%02hhx", $0) }.joined() - return String(hostnameHash.prefix(length)) - } - - private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool { + private func checkLocalFilters(hostHash: String, canonicalUrl: URL, for threatKind: ThreatKind) async -> Bool { let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: threatKind)) - let hostnameHash = generateHashPrefix(for: canonicalHost, length: Int.max) - let matchesLocalFilters = filterSet[hostnameHash]?.contains(where: { regex in + let matchesLocalFilters = filterSet[hostHash]?.contains(where: { regex in canonicalUrl.absoluteString.matches(pattern: regex) }) ?? false return matchesLocalFilters } - private func checkApiMatches(canonicalHost: String, canonicalUrl: URL) async -> Match? { - let hashPrefixParam = generateHashPrefix(for: canonicalHost, length: Constants.hashPrefixParamLength) + private func checkApiMatches(hostHash: String, canonicalUrl: URL) async -> Match? { + let hashPrefixParam = String(hostHash.prefix(Constants.hashPrefixParamLength)) let matches: [Match] do { matches = try await apiClient.matches(forHashPrefix: hashPrefixParam).matches @@ -72,7 +66,7 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { } if let match = matches.first(where: { match in - canonicalUrl.absoluteString.matches(pattern: match.regex) + match.hash == hostHash && canonicalUrl.absoluteString.matches(pattern: match.regex) }) { return match } @@ -86,20 +80,21 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return .none } - let hashPrefix = generateHashPrefix(for: canonicalHost, length: Constants.hashPrefixStoreLength) + let hostHash = canonicalHost.sha256 + let hashPrefix = String(hostHash.prefix(Constants.hashPrefixStoreLength)) for threatKind in ThreatKind.allCases /* phishing, malware.. */ { let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind)) guard hashPrefixes.contains(hashPrefix) else { continue } // Check local filterSet first - if await checkLocalFilters(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl, for: threatKind) { + if await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) { eventMapping.fire(.errorPageShown(clientSideHit: true)) return threatKind } // If nothing found, hit the API to get matches - let match = await checkApiMatches(canonicalHost: canonicalHost, canonicalUrl: canonicalUrl) + let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) if let match { eventMapping.fire(.errorPageShown(clientSideHit: false)) return match.category.map(ThreatKind.init) ?? threatKind diff --git a/Tests/CommonTests/Extensions/StringExtensionTests.swift b/Tests/CommonTests/Extensions/StringExtensionTests.swift index bcf415895..65abb79c8 100644 --- a/Tests/CommonTests/Extensions/StringExtensionTests.swift +++ b/Tests/CommonTests/Extensions/StringExtensionTests.swift @@ -16,8 +16,10 @@ // limitations under the License. // +import CryptoKit import Foundation import XCTest + @testable import Common final class StringExtensionTests: XCTestCase { @@ -370,4 +372,13 @@ final class StringExtensionTests: XCTestCase { } } + func testSha256() { + let string = "Hello, World! This is a test string." + let hash = string.sha256 + let expected = "3c2b805ab0038afb0629e1d598ae73e0caabb69de03e96762977d34e8ba428bf" + let expectedSHA256 = SHA256.hash(data: Data(string.utf8)).map { String(format: "%02hhx", $0) }.joined() + XCTAssertEqual(hash, expected) + XCTAssertEqual(hash, expectedSHA256) + } + } From e6d3d571a5ff8c30fafd6ecd71b246ca1c210b3b Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Wed, 27 Nov 2024 18:35:25 +0600 Subject: [PATCH 07/15] fix linter issues --- Sources/MaliciousSiteProtection/Model/FilterDictionary.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift index 79cb6632e..d9da81147 100644 --- a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -38,7 +38,7 @@ public struct FilterDictionary: Codable, Equatable { /// ``` public var filters: [String: Set] - init(revision: Int, filters: [String: Set]) { + public init(revision: Int, filters: [String: Set]) { self.filters = filters self.revision = revision } @@ -53,7 +53,7 @@ public struct FilterDictionary: Codable, Equatable { withUnsafeMutablePointer(to: &filters[filter.hash]) { item in item.pointee?.remove(filter.regex) if item.pointee?.isEmpty == true { - item.pointee = nil // TODO: Validate deallocation + item.pointee = nil } } } From 47d9dae5de11ae4a44d9f741862375b1cecbb7fb Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 16:29:20 +0600 Subject: [PATCH 08/15] fix typo --- .../MaliciousSiteProtectionDataManagerTests.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift index 91e3c228e..5164f78d3 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionDataManagerTests.swift @@ -60,16 +60,16 @@ class MaliciousSiteProtectionDataManagerTests: XCTestCase { func testWhenNoDataSavedThenProviderDataReturned() async { clearDatasets() - let expectedFilerSet = Set([Filter(hash: "some", regex: "some")]) - let expectedFilerDict = FilterDictionary(revision: 65, items: expectedFilerSet) + let expectedFilterSet = Set([Filter(hash: "some", regex: "some")]) + let expectedFilterDict = FilterDictionary(revision: 65, items: expectedFilterSet) let expectedHashPrefix = Set(["sassa"]) - embeddedDataProvider.filterSet = expectedFilerSet + embeddedDataProvider.filterSet = expectedFilterSet embeddedDataProvider.hashPrefixes = expectedHashPrefix let actualFilterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) let actualHashPrefix = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) - XCTAssertEqual(actualFilterSet, expectedFilerDict) + XCTAssertEqual(actualFilterSet, expectedFilterDict) XCTAssertEqual(actualHashPrefix.set, expectedHashPrefix) } From 6e1c710b9a0efeaafb96490a03a13ec07e2e8aaa Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 16:50:09 +0600 Subject: [PATCH 09/15] add a comment --- .../Model/FilterDictionary.swift | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift index d9da81147..2edd25b31 100644 --- a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -50,6 +50,21 @@ public struct FilterDictionary: Codable, Equatable { public mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { for filter in itemsToDelete { + // Remove the filter from the Set stored in the Dictionary by hash used as a key. + // If the Set becomes empty – remove the Set value from the Dictionary. + // + // The following code is equivalent to this one but without the Set value being copied + // or key being searched multiple times: + /* + if var filterSet = self.filters[filter.hash] { + filterSet.remove(filter.regex) + if filterSet.isEmpty { + self.filters[filter.hash] = nil + } else { + self.filters[filter.hash] = filterSet + } + } + */ withUnsafeMutablePointer(to: &filters[filter.hash]) { item in item.pointee?.remove(filter.regex) if item.pointee?.isEmpty == true { From a2d11c00c4c6277a458c0802b2f15712495c0603 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 16:56:35 +0600 Subject: [PATCH 10/15] Improve Type naming --- .../Model/HashPrefixSet.swift | 1 + ...entallyUpdatableMaliciousSiteDataSet.swift | 11 +++++----- .../Model/LoadableFromEmbeddedData.swift | 8 ++++---- .../Model/StoredData.swift | 20 +++++++++---------- .../Services/DataManager.swift | 16 +++++++-------- .../Services/EmbeddedDataProvider.swift | 6 +++--- .../Services/UpdateManager.swift | 8 ++++---- ...ckMaliciousSiteProtectionDataManager.swift | 6 +++--- ...usSiteProtectionEmbeddedDataProvider.swift | 6 +++--- .../MockPhishingDetectionUpdateManager.swift | 2 +- 10 files changed, 43 insertions(+), 41 deletions(-) diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift index 06a5fc7fa..9300e58b1 100644 --- a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift +++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift @@ -18,6 +18,7 @@ import Foundation +/// Structure storing a Set of hash prefixes ["6fe1e7c8","1d760415",...] and a revision of the set. public struct HashPrefixSet: Codable, Equatable { public var revision: Int diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift index 773659cdb..3d65553c4 100644 --- a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift @@ -19,7 +19,8 @@ public protocol IncrementallyUpdatableMaliciousSiteDataSet: Codable { /// Set Element Type (Hash Prefix or Filter) associatedtype Element: Codable, Hashable - associatedtype APIRequestType: MaliciousSiteDataChangeSetAPIRequestProtocol, APIRequestProtocol where APIRequestType.ResponseType == APIClient.ChangeSetResponse + /// API Request type used to fetch updates for the data set + associatedtype APIRequest: MaliciousSiteDataChangeSetAPIRequestProtocol, APIRequestProtocol where APIRequest.ResponseType == APIClient.ChangeSetResponse var revision: Int { get set } @@ -46,16 +47,16 @@ extension IncrementallyUpdatableMaliciousSiteDataSet { extension HashPrefixSet: IncrementallyUpdatableMaliciousSiteDataSet { public typealias Element = String - public typealias APIRequestType = APIClient.Request.HashPrefixes + public typealias APIRequest = APIClient.Request.HashPrefixes - public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequestType { + public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { .hashPrefixes(threatKind: threatKind, revision: revision) } } extension FilterDictionary: IncrementallyUpdatableMaliciousSiteDataSet { public typealias Element = Filter - public typealias APIRequestType = APIClient.Request.FilterSet + public typealias APIRequest = APIClient.Request.FilterSet public init(revision: Int, items: some Sequence) { let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in @@ -64,7 +65,7 @@ extension FilterDictionary: IncrementallyUpdatableMaliciousSiteDataSet { self.init(revision: revision, filters: filtersDictionary) } - public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequestType { + public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { .filterSet(threatKind: threatKind, revision: revision) } } diff --git a/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift index d6eddce09..be67cb6fc 100644 --- a/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift +++ b/Sources/MaliciousSiteProtection/Model/LoadableFromEmbeddedData.swift @@ -16,19 +16,19 @@ // limitations under the License. // -public protocol LoadableFromEmbeddedData { +public protocol LoadableFromEmbeddedData { /// Set Element Type (Hash Prefix or Filter) associatedtype Element /// Decoded data type stored in the embedded json file - associatedtype EmbeddedDataSetType: Decodable, Sequence where EmbeddedDataSetType.Element == Self.Element + associatedtype EmbeddedDataSet: Decodable, Sequence where EmbeddedDataSet.Element == Self.Element init(revision: Int, items: some Sequence) } extension HashPrefixSet: LoadableFromEmbeddedData { - public typealias EmbeddedDataSetType = [String] + public typealias EmbeddedDataSet = [String] } extension FilterDictionary: LoadableFromEmbeddedData { - public typealias EmbeddedDataSetType = [Filter] + public typealias EmbeddedDataSet = [Filter] } diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift index 7164bdca7..a4d294188 100644 --- a/Sources/MaliciousSiteProtection/Model/StoredData.swift +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -18,9 +18,9 @@ import Foundation -public protocol MaliciousSiteDataKeyProtocol: Hashable { - associatedtype EmbeddedDataSetType: Decodable - associatedtype DataSetType: IncrementallyUpdatableMaliciousSiteDataSet, LoadableFromEmbeddedData +public protocol MaliciousSiteDataKey: Hashable { + associatedtype EmbeddedDataSet: Decodable + associatedtype DataSet: IncrementallyUpdatableMaliciousSiteDataSet, LoadableFromEmbeddedData var dataType: DataManager.StoredDataType { get } var threatKind: ThreatKind { get } @@ -42,7 +42,7 @@ public extension DataManager { } } - var dataKey: any MaliciousSiteDataKeyProtocol { + var dataKey: any MaliciousSiteDataKey { switch self { case .hashPrefixSet(let key): key case .filterSet(let key): key @@ -70,8 +70,8 @@ public extension DataManager { } public extension DataManager.StoredDataType { - struct HashPrefixes: MaliciousSiteDataKeyProtocol { - public typealias DataSetType = HashPrefixSet + struct HashPrefixes: MaliciousSiteDataKey { + public typealias DataSet = HashPrefixSet public let threatKind: ThreatKind @@ -80,15 +80,15 @@ public extension DataManager.StoredDataType { } } } -extension MaliciousSiteDataKeyProtocol where Self == DataManager.StoredDataType.HashPrefixes { +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.HashPrefixes { static func hashPrefixes(threatKind: ThreatKind) -> Self { .init(threatKind: threatKind) } } public extension DataManager.StoredDataType { - struct FilterSet: MaliciousSiteDataKeyProtocol { - public typealias DataSetType = FilterDictionary + struct FilterSet: MaliciousSiteDataKey { + public typealias DataSet = FilterDictionary public let threatKind: ThreatKind @@ -97,7 +97,7 @@ public extension DataManager.StoredDataType { } } } -extension MaliciousSiteDataKeyProtocol where Self == DataManager.StoredDataType.FilterSet { +extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.FilterSet { static func filterSet(threatKind: ThreatKind) -> Self { .init(threatKind: threatKind) } diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift b/Sources/MaliciousSiteProtection/Services/DataManager.swift index 0e2154648..abbcf6d30 100644 --- a/Sources/MaliciousSiteProtection/Services/DataManager.swift +++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift @@ -20,8 +20,8 @@ import Foundation import os public protocol DataManaging { - func dataSet(for key: DataKey) async -> DataKey.DataSetType - func store(_ dataSet: DataKey.DataSetType, for key: DataKey) async + func dataSet(for key: DataKey) async -> DataKey.DataSet + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async } public actor DataManager: DataManaging { @@ -40,10 +40,10 @@ public actor DataManager: DataManaging { self.fileNameProvider = fileNameProvider } - public func dataSet(for key: DataKey) -> DataKey.DataSetType { + public func dataSet(for key: DataKey) -> DataKey.DataSet { let dataType = key.dataType // return cached dataSet if available - if let data = store[key.dataType] as? DataKey.DataSetType { + if let data = store[key.dataType] as? DataKey.DataSet { return data } @@ -61,14 +61,14 @@ public actor DataManager: DataManaging { return dataSet } - private func readStoredDataSet(for key: DataKey) -> DataKey.DataSetType? { + private func readStoredDataSet(for key: DataKey) -> DataKey.DataSet? { let dataType = key.dataType let fileName = fileNameProvider(dataType) guard let data = fileStore.read(from: fileName) else { return nil } - let storedDataSet: DataKey.DataSetType + let storedDataSet: DataKey.DataSet do { - storedDataSet = try JSONDecoder().decode(DataKey.DataSetType.self, from: data) + storedDataSet = try JSONDecoder().decode(DataKey.DataSet.self, from: data) } catch { Logger.dataManager.error("Error decoding \(fileName): \(error.localizedDescription)") return nil @@ -84,7 +84,7 @@ public actor DataManager: DataManaging { return storedDataSet } - public func store(_ dataSet: DataKey.DataSetType, for key: DataKey) { + public func store(_ dataSet: DataKey.DataSet, for key: DataKey) { let dataType = key.dataType let fileName = fileNameProvider(dataType) self.store[dataType] = dataSet diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift index 084199c8b..6371b08b0 100644 --- a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -24,12 +24,12 @@ public protocol EmbeddedDataProviding { func url(for dataType: DataManager.StoredDataType) -> URL func hash(for dataType: DataManager.StoredDataType) -> String - func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet } extension EmbeddedDataProviding { - public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType { + public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet { let dataType = key.dataType let url = url(for: dataType) let data: Data @@ -42,7 +42,7 @@ extension EmbeddedDataProviding { fatalError("Could not load embedded data set at \(url.path): \(error)") } do { - let result = try JSONDecoder().decode(DataKey.EmbeddedDataSetType.self, from: data) + let result = try JSONDecoder().decode(DataKey.EmbeddedDataSet.self, from: data) return result } catch { fatalError("Could not decode embedded data set at \(url.path): \(error)") diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index 0dd7d1106..e1b6f0f4b 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -21,7 +21,7 @@ import Common import os public protocol UpdateManaging { - func updateData(for key: some MaliciousSiteDataKeyProtocol) async + func updateData(for key: some MaliciousSiteDataKey) async func updateFilterSet() async func updateHashPrefixes() async @@ -36,15 +36,15 @@ public struct UpdateManager: UpdateManaging { self.dataManager = dataManager } - public func updateData(for key: DataKey) async { + public func updateData(for key: DataKey) async { // load currently stored data set var dataSet = await dataManager.dataSet(for: key) let oldRevision = dataSet.revision // get change set from current revision from API - let changeSet: APIClient.ChangeSetResponse + let changeSet: APIClient.ChangeSetResponse do { - let request = DataKey.DataSetType.APIRequestType(threatKind: key.threatKind, revision: oldRevision) + let request = DataKey.DataSet.APIRequest(threatKind: key.threatKind, revision: oldRevision) changeSet = try await apiClient.load(request) } catch { Logger.updateManager.error("error fetching filter set: \(error)") diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index 963445abe..b371fb178 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -22,11 +22,11 @@ import MaliciousSiteProtection public class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() - public func dataSet(for key: DataKey) async -> DataKey.DataSetType where DataKey : MaliciousSiteProtection.MaliciousSiteDataKeyProtocol { - store[key.dataType] as? DataKey.DataSetType ?? .init(revision: 0, items: []) + public func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { + store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } - public func store(_ dataSet: DataKey.DataSetType, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKeyProtocol { + public func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { store[key.dataType] = dataSet } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 2cf9d25da..2eaf6864d 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -38,14 +38,14 @@ public class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProte "" } - public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSetType where DataKey : MaliciousSiteDataKeyProtocol { + public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey : MaliciousSiteDataKey { switch key.dataType { case .filterSet: self.loadFilterSetCalled = true - return Array(filterSet) as! DataKey.EmbeddedDataSetType + return Array(filterSet) as! DataKey.EmbeddedDataSet case .hashPrefixSet: self.loadHashPrefixesCalled = true - return Array(hashPrefixes) as! DataKey.EmbeddedDataSetType + return Array(hashPrefixes) as! DataKey.EmbeddedDataSet } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index c7fac96b4..49717ba68 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -25,7 +25,7 @@ public class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateM var didUpdateHashPrefixes = false var completionHandler: (() -> Void)? - public func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKeyProtocol) async { + public func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { switch key.dataType { case .filterSet: await updateFilterSet() case .hashPrefixSet: await updateHashPrefixes() From f4b01a78b774791ed73d04a871205afbbe0deb1b Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 17:16:00 +0600 Subject: [PATCH 11/15] Update type naming --- .../API/APIClient.swift | 8 +++---- .../API/APIRequest.swift | 24 +++++++++---------- .../MaliciousSiteDetector.swift | 5 ++-- ...entallyUpdatableMaliciousSiteDataSet.swift | 2 +- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 678171839..ce48a51b1 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -21,7 +21,7 @@ import Foundation import Networking public protocol APIClientProtocol { - func load(_ requestConfig: Request) async throws -> Request.ResponseType + func load(_ requestConfig: Request) async throws -> Request.Response } public extension APIClientProtocol where Self == APIClient { @@ -50,7 +50,7 @@ public extension APIClient { } var defaultHeaders: APIRequestV2.HeadersV2 { - .init(userAgent: APIRequest.Headers.userAgent) + .init(userAgent: Networking.APIRequest.Headers.userAgent) } enum APIPath { @@ -103,14 +103,14 @@ public struct APIClient: APIClientProtocol { self.service = service } - public func load(_ requestConfig: Request) async throws -> Request.ResponseType { + public func load(_ requestConfig: Request) async throws -> Request.Response { let requestType = requestConfig.requestType let headers = environment.headers(for: requestType) let url = environment.url(for: requestType) let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) let response = try await service.fetch(request: apiRequest) - let result: Request.ResponseType = try response.decodeBody() + let result: Request.Response = try response.decodeBody() return result } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift index 0168b01cf..099684a87 100644 --- a/Sources/MaliciousSiteProtection/API/APIRequest.swift +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -18,11 +18,11 @@ import Foundation -public protocol APIRequestProtocol { - associatedtype ResponseType: Decodable +public protocol APIRequest { + associatedtype Response: Decodable var requestType: APIClient.Request { get } } -public protocol MaliciousSiteDataChangeSetAPIRequestProtocol: APIRequestProtocol { +public protocol ThreatDataChangeSetAPIRequest: APIRequest { init(threatKind: ThreatKind, revision: Int?) } @@ -34,8 +34,8 @@ public extension APIClient { } } public extension APIClient.Request { - struct HashPrefixes: MaliciousSiteDataChangeSetAPIRequestProtocol { - public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet + struct HashPrefixes: ThreatDataChangeSetAPIRequest { + public typealias Response = APIClient.Response.HashPrefixesChangeSet public let threatKind: ThreatKind public let revision: Int? @@ -50,15 +50,15 @@ public extension APIClient.Request { } } } -extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes { +extension APIRequest where Self == APIClient.Request.HashPrefixes { static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { .init(threatKind: threatKind, revision: revision) } } public extension APIClient.Request { - struct FilterSet: MaliciousSiteDataChangeSetAPIRequestProtocol { - public typealias ResponseType = APIClient.Response.FiltersChangeSet + struct FilterSet: ThreatDataChangeSetAPIRequest { + public typealias Response = APIClient.Response.FiltersChangeSet public let threatKind: ThreatKind public let revision: Int? @@ -73,15 +73,15 @@ public extension APIClient.Request { } } } -extension APIRequestProtocol where Self == APIClient.Request.FilterSet { +extension APIRequest where Self == APIClient.Request.FilterSet { static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self { .init(threatKind: threatKind, revision: revision) } } public extension APIClient.Request { - struct Matches: APIRequestProtocol { - public typealias ResponseType = APIClient.Response.Matches + struct Matches: APIRequest { + public typealias Response = APIClient.Response.Matches public let hashPrefix: String @@ -90,7 +90,7 @@ public extension APIClient.Request { } } } -extension APIRequestProtocol where Self == APIClient.Request.Matches { +extension APIRequest where Self == APIClient.Request.Matches { static func matches(hashPrefix: String) -> Self { .init(hashPrefix: hashPrefix) } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 82704fc34..0e7fb2dca 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -21,6 +21,9 @@ import CryptoKit import Foundation public protocol MaliciousSiteDetecting { + /// Evaluates the given URL to determine its threat level. + /// - Parameter url: The URL to evaluate. + /// - Returns: An optional ThreatKind indicating the type of threat, or nil if no threat is detected. func evaluate(_ url: URL) async -> ThreatKind? } @@ -74,8 +77,6 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { } /// Evaluates the given URL to determine its threat level. - /// - Parameter url: The URL to evaluate. - /// - Returns: An optional ThreatKind indicating the type of threat, or nil if no threat is detected. public func evaluate(_ url: URL) async -> ThreatKind? { guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return .none } diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift index 3d65553c4..d829e85d8 100644 --- a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift @@ -20,7 +20,7 @@ public protocol IncrementallyUpdatableMaliciousSiteDataSet: Codable { /// Set Element Type (Hash Prefix or Filter) associatedtype Element: Codable, Hashable /// API Request type used to fetch updates for the data set - associatedtype APIRequest: MaliciousSiteDataChangeSetAPIRequestProtocol, APIRequestProtocol where APIRequest.ResponseType == APIClient.ChangeSetResponse + associatedtype APIRequest: ThreatDataChangeSetAPIRequest where APIRequest.Response == APIClient.ChangeSetResponse var revision: Int { get set } From 89488e0d10eb5cf5e8ecf2e54ed1ce948bf021af Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 18:59:55 +0600 Subject: [PATCH 12/15] Improve naming and types visibility --- Package.swift | 3 +- .../API/APIClient.swift | 47 ++++++------ .../API/APIRequest.swift | 75 +++++++++++-------- .../MaliciousSiteDetector.swift | 8 +- .../Model/FilterDictionary.swift | 12 +-- .../Model/HashPrefixSet.swift | 12 +-- ...entallyUpdatableMaliciousSiteDataSet.swift | 28 +++---- .../Model/StoredData.swift | 16 ++-- .../PhishingDetectionDataActivities.swift | 6 +- .../Services/DataManager.swift | 6 +- .../Services/EmbeddedDataProvider.swift | 10 ++- .../Services/UpdateManager.swift | 12 ++- ...aliciousSiteProtectionAPIClientTests.swift | 2 +- ...iousSiteProtectionUpdateManagerTests.swift | 2 +- ...MockMaliciousSiteProtectionAPIClient.swift | 18 ++--- ...ckMaliciousSiteProtectionDataManager.swift | 8 +- ...usSiteProtectionEmbeddedDataProvider.swift | 14 ++-- .../MockPhishingDetectionUpdateManager.swift | 12 +-- 18 files changed, 156 insertions(+), 135 deletions(-) diff --git a/Package.swift b/Package.swift index c074ced7e..1841d6dbe 100644 --- a/Package.swift +++ b/Package.swift @@ -644,7 +644,8 @@ let package = Package( .testTarget( name: "DuckPlayerTests", dependencies: [ - "DuckPlayer" + "DuckPlayer", + "BrowserServicesKitTestsUtils", ] ), diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index ce48a51b1..d5f17d67f 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -20,22 +20,21 @@ import Common import Foundation import Networking -public protocol APIClientProtocol { - func load(_ requestConfig: Request) async throws -> Request.Response -} - -public extension APIClientProtocol where Self == APIClient { - static var production: APIClientProtocol { APIClient(environment: .production) } - static var staging: APIClientProtocol { APIClient(environment: .staging) } +extension APIClient { + // used internally for testing + protocol Mockable { + func load(_ requestConfig: Request) async throws -> Request.Response + } } +extension APIClient: APIClient.Mockable {} public protocol APIClientEnvironment { - func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 - func url(for request: APIClient.Request) -> URL + func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 + func url(for requestType: APIRequestType) -> URL } -public extension APIClient { - enum DefaultEnvironment: APIClientEnvironment { +public extension MaliciousSiteDetector { + enum APIEnvironment: APIClientEnvironment { case production case staging @@ -65,8 +64,8 @@ public extension APIClient { static let hashPrefix = "hashPrefix" } - public func url(for request: APIClient.Request) -> URL { - switch request { + public func url(for requestType: APIRequestType) -> URL { + switch requestType { case .hashPrefixSet(let configuration): endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ QueryParameter.category: configuration.threatKind.rawValue, @@ -82,35 +81,31 @@ public extension APIClient { } } - public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 { + public func headers(for requestType: APIRequestType) -> APIRequestV2.HeadersV2 { defaultHeaders } } } -public struct APIClient: APIClientProtocol { +struct APIClient { let environment: APIClientEnvironment private let service: APIService - public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) { - self.init(environment: environment as APIClientEnvironment, service: service) - } - - public init(environment: APIClientEnvironment, service: APIService) { + init(environment: APIClientEnvironment, service: APIService = DefaultAPIService(urlSession: .shared)) { self.environment = environment self.service = service } - public func load(_ requestConfig: Request) async throws -> Request.Response { + func load(_ requestConfig: R) async throws -> R.Response { let requestType = requestConfig.requestType let headers = environment.headers(for: requestType) let url = environment.url(for: requestType) let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) let response = try await service.fetch(request: apiRequest) - let result: Request.Response = try response.decodeBody() + let result: R.Response = try response.decodeBody() return result } @@ -118,18 +113,18 @@ public struct APIClient: APIClientProtocol { } // MARK: - Convenience -extension APIClientProtocol { - public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { +extension APIClient.Mockable { + func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { let result = try await load(.filterSet(threatKind: threatKind, revision: revision)) return result } - public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { + func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision)) return result } - public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { + func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { let result = try await load(.matches(hashPrefix: hashPrefix)) return result } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift index 099684a87..1efc8c01b 100644 --- a/Sources/MaliciousSiteProtection/API/APIRequest.swift +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -18,79 +18,88 @@ import Foundation -public protocol APIRequest { - associatedtype Response: Decodable - var requestType: APIClient.Request { get } -} -public protocol ThreatDataChangeSetAPIRequest: APIRequest { - init(threatKind: ThreatKind, revision: Int?) +// Enumerated request type to delegate URLs forming to an API environment instance +public enum APIRequestType { + case hashPrefixSet(APIRequestType.HashPrefixes) + case filterSet(APIRequestType.FilterSet) + case matches(APIRequestType.Matches) } -public extension APIClient { - enum Request { - case hashPrefixSet(HashPrefixes) - case filterSet(FilterSet) - case matches(Matches) +extension APIClient { + // Protocol for defining typed requests with a specific response type. + protocol Request { + associatedtype Response: Decodable // Strongly-typed response type + var requestType: APIRequestType { get } // Enumerated type of request being made + } + + // Protocol for requests that modify a set of malicious site detection data + // (returning insertions/removals along with the updated revision) + protocol ChangeSetRequest: Request { + init(threatKind: ThreatKind, revision: Int?) } } -public extension APIClient.Request { - struct HashPrefixes: ThreatDataChangeSetAPIRequest { - public typealias Response = APIClient.Response.HashPrefixesChangeSet - public let threatKind: ThreatKind - public let revision: Int? +public extension APIRequestType { + struct HashPrefixes: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.HashPrefixesChangeSet + + let threatKind: ThreatKind + let revision: Int? - public init(threatKind: ThreatKind, revision: Int?) { + init(threatKind: ThreatKind, revision: Int?) { self.threatKind = threatKind self.revision = revision } - public var requestType: APIClient.Request { + var requestType: APIRequestType { .hashPrefixSet(self) } } } -extension APIRequest where Self == APIClient.Request.HashPrefixes { +/// extension to call generic `load(_: some Request)` method like this: `load(.hashPrefixes(…))` +extension APIClient.Request where Self == APIRequestType.HashPrefixes { static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { .init(threatKind: threatKind, revision: revision) } } -public extension APIClient.Request { - struct FilterSet: ThreatDataChangeSetAPIRequest { - public typealias Response = APIClient.Response.FiltersChangeSet +public extension APIRequestType { + struct FilterSet: APIClient.ChangeSetRequest { + typealias Response = APIClient.Response.FiltersChangeSet - public let threatKind: ThreatKind - public let revision: Int? + let threatKind: ThreatKind + let revision: Int? - public init(threatKind: ThreatKind, revision: Int?) { + init(threatKind: ThreatKind, revision: Int?) { self.threatKind = threatKind self.revision = revision } - public var requestType: APIClient.Request { + var requestType: APIRequestType { .filterSet(self) } } } -extension APIRequest where Self == APIClient.Request.FilterSet { +/// extension to call generic `load(_: some Request)` method like this: `load(.filterSet(…))` +extension APIClient.Request where Self == APIRequestType.FilterSet { static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self { .init(threatKind: threatKind, revision: revision) } } -public extension APIClient.Request { - struct Matches: APIRequest { - public typealias Response = APIClient.Response.Matches +public extension APIRequestType { + struct Matches: APIClient.Request { + typealias Response = APIClient.Response.Matches - public let hashPrefix: String + let hashPrefix: String - public var requestType: APIClient.Request { + var requestType: APIRequestType { .matches(self) } } } -extension APIRequest where Self == APIClient.Request.Matches { +/// extension to call generic `load(_: some Request)` method like this: `load(.matches(…))` +extension APIClient.Request where Self == APIRequestType.Matches { static func matches(hashPrefix: String) -> Self { .init(hashPrefix: hashPrefix) } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 0e7fb2dca..16bc386dc 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -39,11 +39,15 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { static let hashPrefixParamLength: Int = 4 } - private let apiClient: APIClientProtocol + private let apiClient: APIClient.Mockable private let dataManager: DataManaging private let eventMapping: EventMapping - public init(apiClient: APIClientProtocol = APIClient(), dataManager: DataManaging, eventMapping: EventMapping) { + public convenience init(apiEnvironment: APIClientEnvironment, dataManager: DataManager, eventMapping: EventMapping) { + self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager, eventMapping: eventMapping) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging, eventMapping: EventMapping) { self.apiClient = apiClient self.dataManager = dataManager self.eventMapping = eventMapping diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift index 2edd25b31..9e6026f41 100644 --- a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -18,10 +18,10 @@ import Foundation -public struct FilterDictionary: Codable, Equatable { +struct FilterDictionary: Codable, Equatable { /// Filter set revision - public var revision: Int + var revision: Int /// [Hash: [RegEx]] mapping /// @@ -36,9 +36,9 @@ public struct FilterDictionary: Codable, Equatable { /// ... /// } /// ``` - public var filters: [String: Set] + var filters: [String: Set] - public init(revision: Int, filters: [String: Set]) { + init(revision: Int, filters: [String: Set]) { self.filters = filters self.revision = revision } @@ -48,7 +48,7 @@ public struct FilterDictionary: Codable, Equatable { filters[hash] } - public mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == Filter { for filter in itemsToDelete { // Remove the filter from the Set stored in the Dictionary by hash used as a key. // If the Set becomes empty – remove the Set value from the Dictionary. @@ -74,7 +74,7 @@ public struct FilterDictionary: Codable, Equatable { } } - public mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter { + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == Filter { for filter in itemsToAdd { filters[filter.hash, default: []].insert(filter.regex) } diff --git a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift index 9300e58b1..7aec5244d 100644 --- a/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift +++ b/Sources/MaliciousSiteProtection/Model/HashPrefixSet.swift @@ -19,21 +19,21 @@ import Foundation /// Structure storing a Set of hash prefixes ["6fe1e7c8","1d760415",...] and a revision of the set. -public struct HashPrefixSet: Codable, Equatable { +struct HashPrefixSet: Codable, Equatable { - public var revision: Int - public var set: Set + var revision: Int + var set: Set - public init(revision: Int, items: some Sequence) { + init(revision: Int, items: some Sequence) { self.revision = revision self.set = Set(items) } - public mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String { + mutating func subtract(_ itemsToDelete: Seq) where Seq.Element == String { set.subtract(itemsToDelete) } - public mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String { + mutating func formUnion(_ itemsToAdd: Seq) where Seq.Element == String { set.formUnion(itemsToAdd) } diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift index d829e85d8..3d5a74b14 100644 --- a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift @@ -1,5 +1,5 @@ // -// IncrementallyUpdatableMaliciousSiteDataSet.swift +// IncrementallyUpdatableDataSet.swift // // Copyright © 2024 DuckDuckGo. All rights reserved. // @@ -16,11 +16,11 @@ // limitations under the License. // -public protocol IncrementallyUpdatableMaliciousSiteDataSet: Codable { +protocol IncrementallyUpdatableDataSet: Codable { /// Set Element Type (Hash Prefix or Filter) associatedtype Element: Codable, Hashable /// API Request type used to fetch updates for the data set - associatedtype APIRequest: ThreatDataChangeSetAPIRequest where APIRequest.Response == APIClient.ChangeSetResponse + associatedtype APIRequest: APIClient.ChangeSetRequest where APIRequest.Response == APIClient.ChangeSetResponse var revision: Int { get set } @@ -33,8 +33,8 @@ public protocol IncrementallyUpdatableMaliciousSiteDataSet: Codable { mutating func apply(_ changeSet: APIClient.ChangeSetResponse) } -extension IncrementallyUpdatableMaliciousSiteDataSet { - public mutating func apply(_ changeSet: APIClient.ChangeSetResponse) { +extension IncrementallyUpdatableDataSet { + mutating func apply(_ changeSet: APIClient.ChangeSetResponse) { if changeSet.replace { self = .init(revision: changeSet.revision, items: changeSet.insert) } else { @@ -45,27 +45,27 @@ extension IncrementallyUpdatableMaliciousSiteDataSet { } } -extension HashPrefixSet: IncrementallyUpdatableMaliciousSiteDataSet { - public typealias Element = String - public typealias APIRequest = APIClient.Request.HashPrefixes +extension HashPrefixSet: IncrementallyUpdatableDataSet { + typealias Element = String + typealias APIRequest = APIRequestType.HashPrefixes - public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { .hashPrefixes(threatKind: threatKind, revision: revision) } } -extension FilterDictionary: IncrementallyUpdatableMaliciousSiteDataSet { - public typealias Element = Filter - public typealias APIRequest = APIClient.Request.FilterSet +extension FilterDictionary: IncrementallyUpdatableDataSet { + typealias Element = Filter + typealias APIRequest = APIRequestType.FilterSet - public init(revision: Int, items: some Sequence) { + init(revision: Int, items: some Sequence) { let filtersDictionary = items.reduce(into: [String: Set]()) { result, filter in result[filter.hash, default: []].insert(filter.regex) } self.init(revision: revision, filters: filtersDictionary) } - public static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { + static func apiRequest(for threatKind: ThreatKind, revision: Int) -> APIRequest { .filterSet(threatKind: threatKind, revision: revision) } } diff --git a/Sources/MaliciousSiteProtection/Model/StoredData.swift b/Sources/MaliciousSiteProtection/Model/StoredData.swift index a4d294188..a064be076 100644 --- a/Sources/MaliciousSiteProtection/Model/StoredData.swift +++ b/Sources/MaliciousSiteProtection/Model/StoredData.swift @@ -18,9 +18,9 @@ import Foundation -public protocol MaliciousSiteDataKey: Hashable { +protocol MaliciousSiteDataKey: Hashable { associatedtype EmbeddedDataSet: Decodable - associatedtype DataSet: IncrementallyUpdatableMaliciousSiteDataSet, LoadableFromEmbeddedData + associatedtype DataSet: IncrementallyUpdatableDataSet, LoadableFromEmbeddedData var dataType: DataManager.StoredDataType { get } var threatKind: ThreatKind { get } @@ -71,11 +71,11 @@ public extension DataManager { public extension DataManager.StoredDataType { struct HashPrefixes: MaliciousSiteDataKey { - public typealias DataSet = HashPrefixSet + typealias DataSet = HashPrefixSet - public let threatKind: ThreatKind + let threatKind: ThreatKind - public var dataType: DataManager.StoredDataType { + var dataType: DataManager.StoredDataType { .hashPrefixSet(self) } } @@ -88,11 +88,11 @@ extension MaliciousSiteDataKey where Self == DataManager.StoredDataType.HashPref public extension DataManager.StoredDataType { struct FilterSet: MaliciousSiteDataKey { - public typealias DataSet = FilterDictionary + typealias DataSet = FilterDictionary - public let threatKind: ThreatKind + let threatKind: ThreatKind - public var dataType: DataManager.StoredDataType { + var dataType: DataManager.StoredDataType { .filterSet(self) } } diff --git a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift index bcf06d518..944bc963c 100644 --- a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift +++ b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift @@ -70,7 +70,11 @@ public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandl private var schedulers: [BackgroundActivityScheduler] private var running: Bool = false - public init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManaging) { + public convenience init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManager) { + self.init(hashPrefixInterval: hashPrefixInterval, filterSetInterval: filterSetInterval, updateManager: updateManager) + } + + init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManaging) { let hashPrefixScheduler = BackgroundActivityScheduler( interval: hashPrefixInterval, identifier: "hashPrefixes.update", diff --git a/Sources/MaliciousSiteProtection/Services/DataManager.swift b/Sources/MaliciousSiteProtection/Services/DataManager.swift index abbcf6d30..8e4426dd1 100644 --- a/Sources/MaliciousSiteProtection/Services/DataManager.swift +++ b/Sources/MaliciousSiteProtection/Services/DataManager.swift @@ -19,7 +19,7 @@ import Foundation import os -public protocol DataManaging { +protocol DataManaging { func dataSet(for key: DataKey) async -> DataKey.DataSet func store(_ dataSet: DataKey.DataSet, for key: DataKey) async } @@ -40,7 +40,7 @@ public actor DataManager: DataManaging { self.fileNameProvider = fileNameProvider } - public func dataSet(for key: DataKey) -> DataKey.DataSet { + func dataSet(for key: DataKey) -> DataKey.DataSet { let dataType = key.dataType // return cached dataSet if available if let data = store[key.dataType] as? DataKey.DataSet { @@ -84,7 +84,7 @@ public actor DataManager: DataManaging { return storedDataSet } - public func store(_ dataSet: DataKey.DataSet, for key: DataKey) { + func store(_ dataSet: DataKey.DataSet, for key: DataKey) { let dataType = key.dataType let fileName = fileNameProvider(dataType) self.store[dataType] = dataSet diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift index 6371b08b0..9352ac250 100644 --- a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -24,17 +24,17 @@ public protocol EmbeddedDataProviding { func url(for dataType: DataManager.StoredDataType) -> URL func hash(for dataType: DataManager.StoredDataType) -> String - func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet + func data(withContentsOf url: URL) throws -> Data } extension EmbeddedDataProviding { - public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet { + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet { let dataType = key.dataType let url = url(for: dataType) let data: Data do { - data = try Data(contentsOf: url) + data = try self.data(withContentsOf: url) #if DEBUG assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)") #endif @@ -49,4 +49,8 @@ extension EmbeddedDataProviding { } } + public func data(withContentsOf url: URL) throws -> Data { + try Data(contentsOf: url) + } + } diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index e1b6f0f4b..021725d7e 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -20,7 +20,7 @@ import Foundation import Common import os -public protocol UpdateManaging { +protocol UpdateManaging { func updateData(for key: some MaliciousSiteDataKey) async func updateFilterSet() async @@ -28,15 +28,19 @@ public protocol UpdateManaging { } public struct UpdateManager: UpdateManaging { - private let apiClient: APIClientProtocol + private let apiClient: APIClient.Mockable private let dataManager: DataManaging - public init(apiClient: APIClientProtocol, dataManager: DataManaging) { + public init(apiEnvironment: APIClientEnvironment, dataManager: DataManager) { + self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager) + } + + init(apiClient: APIClient.Mockable, dataManager: DataManaging) { self.apiClient = apiClient self.dataManager = dataManager } - public func updateData(for key: DataKey) async { + func updateData(for key: DataKey) async { // load currently stored data set var dataSet = await dataManager.dataSet(for: key) let oldRevision = dataSet.revision diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index d32d264ea..fcea80939 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -30,7 +30,7 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { override func setUp() { super.setUp() mockService = MockAPIService() - client = .init(environment: .staging, service: mockService) + client = .init(environment: MaliciousSiteDetector.APIEnvironment.staging, service: mockService) } override func tearDown() { diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift index 5eabf6390..3a190e18e 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift @@ -24,7 +24,7 @@ import XCTest class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { var updateManager: MaliciousSiteProtection.UpdateManager! var dataManager: MaliciousSiteProtection.DataManaging! - var apiClient: MaliciousSiteProtection.APIClientProtocol! + var apiClient: MaliciousSiteProtection.APIClient.Mockable! override func setUp() async throws { apiClient = MockMaliciousSiteProtectionAPIClient() diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index 874e4fdbd..473030ca9 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -17,13 +17,13 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClientProtocol { - public var updateHashPrefixesWasCalled: Bool = false - public var updateFilterSetsWasCalled: Bool = false +class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mockable { + var updateHashPrefixesWasCalled: Bool = false + var updateFilterSetsWasCalled: Bool = false - private var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ + var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ 0: .init(insert: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") @@ -73,14 +73,14 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 6, replace: true), ] - public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request: APIRequestProtocol { + func load(_ requestConfig: Request) async throws -> Request.Response where Request : APIClient.Request { switch requestConfig.requestType { case .hashPrefixSet(let configuration): - return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response case .filterSet(let configuration): - return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response case .matches(let configuration): - return _matches(forHashPrefix: configuration.hashPrefix) as! Request.ResponseType + return _matches(forHashPrefix: configuration.hashPrefix) as! Request.Response } } func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index b371fb178..ab27f5f2a 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -17,16 +17,16 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { +class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() - public func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { + func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } - public func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { store[key.dataType] = dataSet } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 2eaf6864d..84542ed4f 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -17,28 +17,28 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { - public var embeddedRevision: Int = 65 +class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { + var embeddedRevision: Int = 65 var loadHashPrefixesCalled: Bool = false var loadFilterSetCalled: Bool = true var hashPrefixes = Set(["aabb"]) var filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) - public func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { embeddedRevision } - public func url(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + func url(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { URL.empty } - public func hash(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + func hash(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> String { "" } - public func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey : MaliciousSiteDataKey { + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey : MaliciousSiteDataKey { switch key.dataType { case .filterSet: self.loadFilterSetCalled = true diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index 49717ba68..605b2118d 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -17,32 +17,32 @@ // import Foundation -import MaliciousSiteProtection +@testable import MaliciousSiteProtection -public class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { +class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { var didUpdateFilterSet = false var didUpdateHashPrefixes = false var completionHandler: (() -> Void)? - public func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { + func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { switch key.dataType { case .filterSet: await updateFilterSet() case .hashPrefixSet: await updateHashPrefixes() } } - public func updateFilterSet() async { + func updateFilterSet() async { didUpdateFilterSet = true checkCompletion() } - public func updateHashPrefixes() async { + func updateHashPrefixes() async { didUpdateHashPrefixes = true checkCompletion() } - private func checkCompletion() { + func checkCompletion() { if didUpdateFilterSet && didUpdateHashPrefixes { completionHandler?() } From 08554913e5b57a30b82a6480afe6a001338289b2 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 19:11:30 +0600 Subject: [PATCH 13/15] Evaluate local filters first before making the API call; Improve docs --- .../MaliciousSiteDetector.swift | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 16bc386dc..6ca115a05 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -21,9 +21,9 @@ import CryptoKit import Foundation public protocol MaliciousSiteDetecting { - /// Evaluates the given URL to determine its threat level. + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). /// - Parameter url: The URL to evaluate. - /// - Returns: An optional ThreatKind indicating the type of threat, or nil if no threat is detected. + /// - Returns: An optional `ThreatKind` indicating the type of threat, or `.none` if no threat is detected. func evaluate(_ url: URL) async -> ThreatKind? } @@ -80,7 +80,7 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { return nil } - /// Evaluates the given URL to determine its threat level. + /// Evaluates the given URL to determine its malicious category (e.g., phishing, malware). public func evaluate(_ url: URL) async -> ThreatKind? { guard let canonicalHost = url.canonicalHost(), let canonicalUrl = url.canonicalURL() else { return .none } @@ -88,28 +88,41 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { let hostHash = canonicalHost.sha256 let hashPrefix = String(hostHash.prefix(Constants.hashPrefixStoreLength)) - for threatKind in ThreatKind.allCases /* phishing, malware.. */ { + // 1. Check for matching hash prefixes. + // The hash prefix list serves as a representation of the entire database: + // every malicious website will have a hash prefix that it collides with. + var hashPrefixMatchingThreatKinds = [ThreatKind]() + for threatKind in ThreatKind.allCases { // e.g., phishing, malware, etc. let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: threatKind)) - guard hashPrefixes.contains(hashPrefix) else { continue } + if hashPrefixes.contains(hashPrefix) { + hashPrefixMatchingThreatKinds.append(threatKind) + } + } - // Check local filterSet first + // Return no threats if no matching hash prefixes are found in the database. + guard !hashPrefixMatchingThreatKinds.isEmpty else { return .none } + + // 2. Check local Filter Sets. + // The filter set acts as a local cache of some database entries, containing + // the 5000 most common threats (or those most likely to collide with daily + // browsing behaviors, based on Clickhouse's top 10k, ranked by Netcraft's risk rating). + for threatKind in hashPrefixMatchingThreatKinds { if await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) { eventMapping.fire(.errorPageShown(clientSideHit: true)) return threatKind } + } - // If nothing found, hit the API to get matches - let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) - if let match { - eventMapping.fire(.errorPageShown(clientSideHit: false)) - return match.category.map(ThreatKind.init) ?? threatKind - } - - // the API detects both phishing and malware so if it didn‘t find any matches it‘s safe to return early. - return nil + // 3. If no locally cached filters matched, we will still make a request to the API + // to check for potential matches on our backend. + let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) + if let match { + eventMapping.fire(.errorPageShown(clientSideHit: false)) + return match.category.map(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0] } return .none } + } From e18392f52e8acfec775138738600082d580d1b60 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 19:19:54 +0600 Subject: [PATCH 14/15] fix linter issues --- .../MaliciousSiteDetector.swift | 11 ++++++----- Sources/MaliciousSiteProtection/Model/Event.swift | 4 ++-- .../Model/FilterDictionary.swift | 5 ----- ...aSet.swift => IncrementallyUpdatableDataSet.swift} | 0 .../PhishingDetectionDataActivities.swift | 2 +- ...iciousSiteProtectionEmbeddedDataProviderTest.swift | 2 +- .../Mocks/MockMaliciousSiteProtectionAPIClient.swift | 2 +- .../MockMaliciousSiteProtectionDataManager.swift | 4 ++-- ...kMaliciousSiteProtectionEmbeddedDataProvider.swift | 6 +++--- .../Mocks/MockPhishingDetectionUpdateManager.swift | 2 +- 10 files changed, 17 insertions(+), 21 deletions(-) rename Sources/MaliciousSiteProtection/Model/{IncrementallyUpdatableMaliciousSiteDataSet.swift => IncrementallyUpdatableDataSet.swift} (100%) diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 6ca115a05..d38fc4c12 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -107,8 +107,9 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { // the 5000 most common threats (or those most likely to collide with daily // browsing behaviors, based on Clickhouse's top 10k, ranked by Netcraft's risk rating). for threatKind in hashPrefixMatchingThreatKinds { - if await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) { - eventMapping.fire(.errorPageShown(clientSideHit: true)) + let matches = await checkLocalFilters(hostHash: hostHash, canonicalUrl: canonicalUrl, for: threatKind) + if matches { + eventMapping.fire(.errorPageShown(clientSideHit: true, threatKind: threatKind)) return threatKind } } @@ -117,12 +118,12 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { // to check for potential matches on our backend. let match = await checkApiMatches(hostHash: hostHash, canonicalUrl: canonicalUrl) if let match { - eventMapping.fire(.errorPageShown(clientSideHit: false)) - return match.category.map(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0] + let threatKind = match.category.flatMap(ThreatKind.init) ?? hashPrefixMatchingThreatKinds[0] + eventMapping.fire(.errorPageShown(clientSideHit: false, threatKind: threatKind)) + return threatKind } return .none } - } diff --git a/Sources/MaliciousSiteProtection/Model/Event.swift b/Sources/MaliciousSiteProtection/Model/Event.swift index 31eab462a..8903f4d70 100644 --- a/Sources/MaliciousSiteProtection/Model/Event.swift +++ b/Sources/MaliciousSiteProtection/Model/Event.swift @@ -27,7 +27,7 @@ public extension PixelKit { } public enum Event: PixelKitEventV2 { - case errorPageShown(clientSideHit: Bool) + case errorPageShown(clientSideHit: Bool, threatKind: ThreatKind) case visitSite case iframeLoaded case updateTaskFailed48h(error: Error?) @@ -50,7 +50,7 @@ public enum Event: PixelKitEventV2 { public var parameters: [String: String]? { switch self { - case .errorPageShown(let clientSideHit): + case .errorPageShown(let clientSideHit, threatKind: _): return [PixelKit.Parameters.clientSideHit: String(clientSideHit)] case .visitSite: return [:] diff --git a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift index 9e6026f41..b67cd82ef 100644 --- a/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift +++ b/Sources/MaliciousSiteProtection/Model/FilterDictionary.swift @@ -38,11 +38,6 @@ struct FilterDictionary: Codable, Equatable { /// ``` var filters: [String: Set] - init(revision: Int, filters: [String: Set]) { - self.filters = filters - self.revision = revision - } - /// Subscript to access regex patterns by SHA256 host name hash subscript(hash: String) -> Set? { filters[hash] diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift similarity index 100% rename from Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableMaliciousSiteDataSet.swift rename to Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift diff --git a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift index 944bc963c..0b811c5b8 100644 --- a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift +++ b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift @@ -71,7 +71,7 @@ public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandl private var running: Bool = false public convenience init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManager) { - self.init(hashPrefixInterval: hashPrefixInterval, filterSetInterval: filterSetInterval, updateManager: updateManager) + self.init(hashPrefixInterval: hashPrefixInterval, filterSetInterval: filterSetInterval, updateManager: updateManager as UpdateManaging) } init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManaging) { diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift index 600352f52..1e3e0df40 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionEmbeddedDataProviderTest.swift @@ -35,7 +35,7 @@ class MaliciousSiteProtectionEmbeddedDataProviderTest: XCTestCase { Bundle.module.url(forResource: "\(key.threatKind)HashPrefixes", withExtension: "json")! } } - + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { switch dataType { case .filterSet(let key): diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index 473030ca9..1a580f5c8 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -73,7 +73,7 @@ class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mo ], revision: 6, replace: true), ] - func load(_ requestConfig: Request) async throws -> Request.Response where Request : APIClient.Request { + func load(_ requestConfig: Request) async throws -> Request.Response where Request: APIClient.Request { switch requestConfig.requestType { case .hashPrefixSet(let configuration): return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.Response diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index ab27f5f2a..ea9a63a71 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -22,11 +22,11 @@ import Foundation class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() - func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { + func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } - func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey : MaliciousSiteProtection.MaliciousSiteDataKey { + func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { store[key.dataType] = dataSet } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 84542ed4f..0beb44dd7 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -29,16 +29,16 @@ class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.E func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { embeddedRevision } - + func url(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { URL.empty } - + func hash(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> String { "" } - func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey : MaliciousSiteDataKey { + func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey: MaliciousSiteDataKey { switch key.dataType { case .filterSet: self.loadFilterSetCalled = true diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index 605b2118d..767f4c13a 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -20,7 +20,7 @@ import Foundation @testable import MaliciousSiteProtection class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging { - + var didUpdateFilterSet = false var didUpdateHashPrefixes = false var completionHandler: (() -> Void)? From 93ad38ef58fa98d4ac447eab99ab1926a03add41 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 23:03:40 +0500 Subject: [PATCH 15/15] Malware protection 4: Refactor Malicious update manager (#1094) Please review the release process for BrowserServicesKit [here](https://app.asana.com/0/1200194497630846/1200837094583426). **Required**: Task/Issue URL: https://app.asana.com/0/1202406491309510/1208033567421351/f Tech design: https://app.asana.com/0/481882893211075/1208736595187321/f iOS PR: macOS PR: What kind of version bump will this require?: Major **Optional**: Tech Design URL: CC: **Description**: - Reduce code duplication by using `Task.periodic` instead of `BackgroundActivities` - Improve testability of the `Task.periodic` runner (`SwiftClocks` package) **Steps to test this PR**: 1. Validate Malicious Site Protection background updates work **OS Testing**: * [ ] iOS 14 * [ ] iOS 15 * [ ] iOS 16 * [ ] macOS 10.15 * [ ] macOS 11 * [ ] macOS 12 --- ###### Internal references: [Software Engineering Expectations](https://app.asana.com/0/59792373528535/199064865822552) [Technical Design Template](https://app.asana.com/0/59792373528535/184709971311943) --- Package.resolved | 27 +++ Package.swift | 4 +- .../Common/Concurrency/TaskExtension.swift | 81 ++++--- .../Logger+MaliciousSiteProtection.swift | 1 - .../Model/IncrementallyUpdatableDataSet.swift | 2 +- .../PhishingDetectionDataActivities.swift | 110 --------- .../Services/EmbeddedDataProvider.swift | 6 +- .../Services/UpdateManager.swift | 39 ++- .../BackgroundActivitySchedulerTests.swift | 57 ----- ...iousSiteProtectionUpdateManagerTests.swift | 226 ++++++++++++++++-- .../BackgroundActivitySchedulerMock.swift | 34 --- ...MockMaliciousSiteProtectionAPIClient.swift | 8 +- ...ckMaliciousSiteProtectionDataManager.swift | 15 +- ...usSiteProtectionEmbeddedDataProvider.swift | 56 ++++- .../MockPhishingDetectionUpdateManager.swift | 5 + ...PhishingDetectionDataActivitiesTests.swift | 48 ---- 16 files changed, 386 insertions(+), 333 deletions(-) delete mode 100644 Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift delete mode 100644 Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift delete mode 100644 Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift delete mode 100644 Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift diff --git a/Package.resolved b/Package.resolved index 4461ccf24..2b5b815cc 100644 --- a/Package.resolved +++ b/Package.resolved @@ -63,6 +63,24 @@ "version" : "3.0.0" } }, + { + "identity" : "swift-clocks", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-clocks.git", + "state" : { + "revision" : "b9b24b69e2adda099a1fa381cda1eeec272d5b53", + "version" : "1.0.5" + } + }, + { + "identity" : "swift-concurrency-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-concurrency-extras", + "state" : { + "revision" : "163409ef7dae9d960b87f34b51587b6609a76c1f", + "version" : "1.3.0" + } + }, { "identity" : "swifter", "kind" : "remoteSourceControl", @@ -89,6 +107,15 @@ "revision" : "5de0a610a7927b638a5fd463a53032c9934a2c3b", "version" : "3.0.0" } + }, + { + "identity" : "xctest-dynamic-overlay", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/xctest-dynamic-overlay", + "state" : { + "revision" : "a3f634d1a409c7979cabc0a71b3f26ffa9fc8af1", + "version" : "1.4.3" + } } ], "version" : 2 diff --git a/Package.swift b/Package.swift index 1841d6dbe..d8b3e2d0c 100644 --- a/Package.swift +++ b/Package.swift @@ -57,7 +57,8 @@ let package = Package( .package(url: "https://github.com/duckduckgo/privacy-dashboard", exact: "7.2.0"), .package(url: "https://github.com/httpswift/swifter.git", exact: "1.5.0"), .package(url: "https://github.com/duckduckgo/bloom_cpp.git", exact: "3.0.0"), - .package(url: "https://github.com/1024jp/GzipSwift.git", exact: "6.0.1") + .package(url: "https://github.com/1024jp/GzipSwift.git", exact: "6.0.1"), + .package(url: "https://github.com/pointfreeco/swift-clocks.git", exact: "1.0.5"), ], targets: [ .target( @@ -654,6 +655,7 @@ let package = Package( dependencies: [ "TestUtils", "MaliciousSiteProtection", + .product(name: "Clocks", package: "swift-clocks"), ], resources: [ .copy("Resources/phishingHashPrefixes.json"), diff --git a/Sources/Common/Concurrency/TaskExtension.swift b/Sources/Common/Concurrency/TaskExtension.swift index a407e4406..65d974a36 100644 --- a/Sources/Common/Concurrency/TaskExtension.swift +++ b/Sources/Common/Concurrency/TaskExtension.swift @@ -18,51 +18,72 @@ import Foundation +public struct Sleeper { + + public static let `default` = Sleeper(sleep: { + try await Task.sleep(interval: $0) + }) + + private let sleep: (TimeInterval) async throws -> Void + + public init(sleep: @escaping (TimeInterval) async throws -> Void) { + self.sleep = sleep + } + + @available(macOS 13.0, iOS 16.0, *) + public init(clock: any Clock) { + self.sleep = { interval in + try await clock.sleep(for: .nanoseconds(UInt64(interval * Double(NSEC_PER_SEC)))) + } + } + + public func sleep(for interval: TimeInterval) async throws { + try await sleep(interval) + } + +} + +public func performPeriodicJob(withDelay delay: TimeInterval? = nil, + interval: TimeInterval, + sleeper: Sleeper = .default, + operation: @escaping @Sendable () async throws -> Void, + cancellationHandler: (@Sendable () async -> Void)? = nil) async throws -> Never { + + do { + if let delay { + try await sleeper.sleep(for: delay) + } + + repeat { + try await operation() + + try await sleeper.sleep(for: interval) + } while true + } catch let error as CancellationError { + await cancellationHandler?() + throw error + } +} + public extension Task where Success == Never, Failure == Error { static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async -> Void, cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task { - Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } - } + return periodic(delay: delay, interval: interval, sleeper: sleeper, operation: { await operation() } as @Sendable () async throws -> Void, cancellationHandler: cancellationHandler) } static func periodic(delay: TimeInterval? = nil, interval: TimeInterval, + sleeper: Sleeper = .default, operation: @escaping @Sendable () async throws -> Void, cancellationHandler: (@Sendable () async -> Void)? = nil) -> Task { Task { - do { - if let delay { - try await Task.sleep(interval: delay) - } - - repeat { - try await operation() - - try await Task.sleep(interval: interval) - } while true - } catch { - await cancellationHandler?() - throw error - } + try await performPeriodicJob(withDelay: delay, interval: interval, sleeper: sleeper, operation: operation, cancellationHandler: cancellationHandler) } } } diff --git a/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift index 15b47d8e1..3e44f3bcd 100644 --- a/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift +++ b/Sources/MaliciousSiteProtection/Logger+MaliciousSiteProtection.swift @@ -25,7 +25,6 @@ public extension os.Logger { public static var api = os.Logger(subsystem: "MSP", category: "API") public static var dataManager = os.Logger(subsystem: "MSP", category: "DataManager") public static var updateManager = os.Logger(subsystem: "MSP", category: "UpdateManager") - static var phishingDetectionTasks = os.Logger(subsystem: "MSP", category: "BackgroundActivities") } } diff --git a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift index 3d5a74b14..8a23785ae 100644 --- a/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift +++ b/Sources/MaliciousSiteProtection/Model/IncrementallyUpdatableDataSet.swift @@ -16,7 +16,7 @@ // limitations under the License. // -protocol IncrementallyUpdatableDataSet: Codable { +protocol IncrementallyUpdatableDataSet: Codable, Equatable { /// Set Element Type (Hash Prefix or Filter) associatedtype Element: Codable, Hashable /// API Request type used to fetch updates for the data set diff --git a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift b/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift deleted file mode 100644 index 0b811c5b8..000000000 --- a/Sources/MaliciousSiteProtection/PhishingDetectionDataActivities.swift +++ /dev/null @@ -1,110 +0,0 @@ -// -// PhishingDetectionDataActivities.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common -import os - -public protocol BackgroundActivityScheduling: Actor { - func start() - func stop() -} -actor BackgroundActivityScheduler: BackgroundActivityScheduling { - - private var task: Task? - private var timer: Timer? - private let interval: TimeInterval - private let identifier: String - private let activity: () async -> Void - - init(interval: TimeInterval, identifier: String, activity: @escaping () async -> Void) { - self.interval = interval - self.identifier = identifier - self.activity = activity - } - - func start() { - stop() - task = Task { - let taskId = UUID().uuidString - while !Task.isCancelled { - await activity() - do { - Logger.phishingDetectionTasks.debug("🟢 \(self.identifier) task was executed in instance \(taskId)") - try await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000)) - } catch { - Logger.phishingDetectionTasks.error("🔴 Error \(self.identifier) task was cancelled before it could finish sleeping.") - break - } - } - } - } - - func stop() { - task?.cancel() - task = nil - } -} - -public protocol PhishingDetectionDataActivityHandling { - func start() - func stop() -} - -public class PhishingDetectionDataActivities: PhishingDetectionDataActivityHandling { - private var schedulers: [BackgroundActivityScheduler] - private var running: Bool = false - - public convenience init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManager) { - self.init(hashPrefixInterval: hashPrefixInterval, filterSetInterval: filterSetInterval, updateManager: updateManager as UpdateManaging) - } - - init(hashPrefixInterval: TimeInterval = 20 * 60, filterSetInterval: TimeInterval = 12 * 60 * 60, updateManager: UpdateManaging) { - let hashPrefixScheduler = BackgroundActivityScheduler( - interval: hashPrefixInterval, - identifier: "hashPrefixes.update", - activity: { await updateManager.updateHashPrefixes() } - ) - let filterSetScheduler = BackgroundActivityScheduler( - interval: filterSetInterval, - identifier: "filterSet.update", - activity: { await updateManager.updateFilterSet() } - ) - self.schedulers = [hashPrefixScheduler, filterSetScheduler] - } - - public func start() { - if !running { - Task { - for scheduler in schedulers { - await scheduler.start() - } - } - running = true - } - } - - public func stop() { - Task { - for scheduler in schedulers { - await scheduler.stop() - } - } - running = false - } -} diff --git a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift index 9352ac250..c9c82c2a0 100644 --- a/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift +++ b/Sources/MaliciousSiteProtection/Services/EmbeddedDataProvider.swift @@ -16,8 +16,8 @@ // limitations under the License. // -import Foundation import CryptoKit +import Foundation public protocol EmbeddedDataProviding { func revision(for dataType: DataManager.StoredDataType) -> Int @@ -39,13 +39,13 @@ extension EmbeddedDataProviding { assert(data.sha256 == hash(for: dataType), "SHA mismatch for \(url.path)") #endif } catch { - fatalError("Could not load embedded data set at \(url.path): \(error)") + fatalError("\(self): Could not load embedded data set at “\(url)”: \(error)") } do { let result = try JSONDecoder().decode(DataKey.EmbeddedDataSet.self, from: data) return result } catch { - fatalError("Could not decode embedded data set at \(url.path): \(error)") + fatalError("\(self): Could not decode embedded data set at “\(url)”: \(error)") } } diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index 021725d7e..2b91ac051 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -23,21 +23,27 @@ import os protocol UpdateManaging { func updateData(for key: some MaliciousSiteDataKey) async - func updateFilterSet() async - func updateHashPrefixes() async + func startPeriodicUpdates() -> Task } public struct UpdateManager: UpdateManaging { + private let apiClient: APIClient.Mockable private let dataManager: DataManaging - public init(apiEnvironment: APIClientEnvironment, dataManager: DataManager) { - self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager) + public typealias UpdateIntervalProvider = (DataManager.StoredDataType) -> TimeInterval? + private let updateIntervalProvider: UpdateIntervalProvider + private let sleeper: Sleeper + + public init(apiEnvironment: APIClientEnvironment, dataManager: DataManager, updateIntervalProvider: @escaping UpdateIntervalProvider) { + self.init(apiClient: APIClient(environment: apiEnvironment), dataManager: dataManager, updateIntervalProvider: updateIntervalProvider) } - init(apiClient: APIClient.Mockable, dataManager: DataManaging) { + init(apiClient: APIClient.Mockable, dataManager: DataManaging, sleeper: Sleeper = .default, updateIntervalProvider: @escaping UpdateIntervalProvider) { self.apiClient = apiClient self.dataManager = dataManager + self.updateIntervalProvider = updateIntervalProvider + self.sleeper = sleeper } func updateData(for key: DataKey) async { @@ -67,12 +73,25 @@ public struct UpdateManager: UpdateManaging { Logger.updateManager.debug("\(type(of: key)).\(key.threatKind) updated from rev.\(oldRevision) to rev.\(dataSet.revision)") } - public func updateFilterSet() async { - await updateData(for: .filterSet(threatKind: .phishing)) - } + public func startPeriodicUpdates() -> Task { + Task.detached { + // run update jobs in background for every data type + try await withThrowingTaskGroup(of: Never.self) { group in + for dataType in DataManager.StoredDataType.allCases { + // get update interval from provider + guard let updateInterval = updateIntervalProvider(dataType) else { continue } + assert(updateInterval > 0) - public func updateHashPrefixes() async { - await updateData(for: .hashPrefixes(threatKind: .phishing)) + group.addTask { + // run periodically until the parent task is cancelled + try await performPeriodicJob(interval: updateInterval, sleeper: sleeper) { + await self.updateData(for: dataType.dataKey) + } + } + } + for try await _ in group {} + } + } } } diff --git a/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift b/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift deleted file mode 100644 index 0640fc16f..000000000 --- a/Tests/MaliciousSiteProtectionTests/BackgroundActivitySchedulerTests.swift +++ /dev/null @@ -1,57 +0,0 @@ -// -// BackgroundActivitySchedulerTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -import Foundation -import XCTest -@testable import MaliciousSiteProtection - -class BackgroundActivitySchedulerTests: XCTestCase { - var scheduler: BackgroundActivityScheduler! - var activityWasRun = false - - override func tearDown() { - scheduler = nil - super.tearDown() - } - - func testStart() async throws { - let expectation = self.expectation(description: "Activity should run") - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - if !self.activityWasRun { - self.activityWasRun = true - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 2) - XCTAssertTrue(activityWasRun) - } - - func testRepeats() async throws { - let expectation = self.expectation(description: "Activity should repeat") - var runCount = 0 - scheduler = BackgroundActivityScheduler(interval: 1, identifier: "test") { - runCount += 1 - if runCount == 2 { - expectation.fulfill() - } - } - await scheduler.start() - await fulfillment(of: [expectation], timeout: 3) - XCTAssertEqual(runCount, 2) - } -} diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift index 3a190e18e..8d46d5cf7 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionUpdateManagerTests.swift @@ -16,30 +16,47 @@ // limitations under the License. // +import Clocks +import Common import Foundation import XCTest @testable import MaliciousSiteProtection class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { + var updateManager: MaliciousSiteProtection.UpdateManager! - var dataManager: MaliciousSiteProtection.DataManaging! + var dataManager: MockMaliciousSiteProtectionDataManager! var apiClient: MaliciousSiteProtection.APIClient.Mockable! + var updateIntervalProvider: UpdateManager.UpdateIntervalProvider! + var clock: TestClock! + var willSleep: ((TimeInterval) -> Void)? + var updateTask: Task? override func setUp() async throws { apiClient = MockMaliciousSiteProtectionAPIClient() dataManager = MockMaliciousSiteProtectionDataManager() - updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager) + clock = TestClock() + + let clockSleeper = Sleeper(clock: clock) + let reportingSleeper = Sleeper { + self.willSleep?($0) + try await clockSleeper.sleep(for: $0) + } + + updateManager = MaliciousSiteProtection.UpdateManager(apiClient: apiClient, dataManager: dataManager, sleeper: reportingSleeper, updateIntervalProvider: { self.updateIntervalProvider($0) }) } - override func tearDown() { + override func tearDown() async throws { updateManager = nil dataManager = nil apiClient = nil + updateIntervalProvider = nil + updateTask?.cancel() } func testUpdateHashPrefixes() async { - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let dataSet = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) XCTAssertEqual(dataSet, HashPrefixSet(revision: 1, items: [ "aa00bb11", @@ -51,7 +68,7 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { } func testUpdateFilterSet() async { - await updateManager.updateFilterSet() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) let dataSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) XCTAssertEqual(dataSet, FilterDictionary(revision: 1, items: [ Filter(hash: "testhash1", regex: ".*example.*"), @@ -72,12 +89,12 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { ] // revision 0 -> 1 - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) // revision 1 -> 2 - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) @@ -116,8 +133,8 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { "a379a6f6" ]), for: .hashPrefixes(threatKind: .phishing)) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) @@ -134,8 +151,8 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { await dataManager.store(FilterDictionary(revision: 3, items: []), for: .filterSet(threatKind: .phishing)) await dataManager.store(HashPrefixSet(revision: 3, items: []), for: .hashPrefixes(threatKind: .phishing)) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) @@ -156,8 +173,8 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { await dataManager.store(FilterDictionary(revision: 4, items: []), for: .filterSet(threatKind: .phishing)) await dataManager.store(HashPrefixSet(revision: 4, items: []), for: .hashPrefixes(threatKind: .phishing)) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) @@ -187,8 +204,8 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { "bb00cc11" ]), for: .hashPrefixes(threatKind: .phishing)) - await updateManager.updateFilterSet() - await updateManager.updateHashPrefixes() + await updateManager.updateData(for: .filterSet(threatKind: .phishing)) + await updateManager.updateData(for: .hashPrefixes(threatKind: .phishing)) let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) @@ -197,4 +214,179 @@ class MaliciousSiteProtectionUpdateManagerTests: XCTestCase { XCTAssertEqual(filterSet, FilterDictionary(revision: 6, items: expectedFilterSet), "Filter set should match the expected set after update.") } + func testWhenPeriodicUpdatesStart_dataSetsAreUpdated() async throws { + self.updateIntervalProvider = { _ in 1 } + + let eHashPrefixesUpdated = expectation(description: "Hash prefixes updated") + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + eHashPrefixesUpdated.fulfill() + } + let eFilterSetUpdated = expectation(description: "Filter set updated") + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + eFilterSetUpdated.fulfill() + } + + updateTask = updateManager.startPeriodicUpdates() + await Task.megaYield(count: 10) + + // expect initial update run instantly + await fulfillment(of: [eHashPrefixesUpdated, eFilterSetUpdated], timeout: 1) + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreEnabled_dataSetsAreUpdatedContinuously() async throws { + // Start periodic updates + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return 2 + case .hashPrefixSet: return 1 + } + } + + let hashPrefixUpdateExpectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let filterSetUpdateExpectations = [ + XCTestExpectation(description: "Filter set rev.1 update received"), + XCTestExpectation(description: "Filter set rev.2 update received"), + XCTestExpectation(description: "Filter set rev.3 update received"), + ] + let hashPrefixSleepExpectations = [ + XCTestExpectation(description: "HP Will Sleep 1"), + XCTestExpectation(description: "HP Will Sleep 2"), + XCTestExpectation(description: "HP Will Sleep 3"), + ] + let filterSetSleepExpectations = [ + XCTestExpectation(description: "FS Will Sleep 1"), + XCTestExpectation(description: "FS Will Sleep 2"), + XCTestExpectation(description: "FS Will Sleep 3"), + ] + + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + hashPrefixUpdateExpectations[data.revision - 1].fulfill() + } + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + filterSetUpdateExpectations[data.revision - 1].fulfill() + } + var hashPrefixSleepIndex = 0 + var filterSetSleepIndex = 0 + self.willSleep = { interval in + if interval == 1 { + hashPrefixSleepExpectations[safe: hashPrefixSleepIndex]?.fulfill() + hashPrefixSleepIndex += 1 + } else { + filterSetSleepExpectations[safe: filterSetSleepIndex]?.fulfill() + filterSetSleepIndex += 1 + } + } + + // expect initial hashPrefixes update run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [hashPrefixUpdateExpectations[0], hashPrefixSleepExpectations[0], filterSetUpdateExpectations[0], filterSetSleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [hashPrefixUpdateExpectations[1], hashPrefixSleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes and v.2 update for filterSet + await fulfillment(of: [hashPrefixUpdateExpectations[2], hashPrefixSleepExpectations[2], filterSetUpdateExpectations[1], filterSetSleepExpectations[1]], timeout: 1) // + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(2)) + // expect to receive v.3 update for filterSet and no update for hashPrefixes (no v.3 updates in the mock) + await fulfillment(of: [filterSetUpdateExpectations[2], filterSetSleepExpectations[2]], timeout: 1) // + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreDisabled_noDataSetsAreUpdated() async throws { + // Start periodic updates + self.updateIntervalProvider = { dataType in + switch dataType { + case .filterSet: return nil // Set update interval to nil for FilterSet + case .hashPrefixSet: return 1 + } + } + + let expectations = [ + XCTestExpectation(description: "Hash prefixes rev.1 update received"), + XCTestExpectation(description: "Hash prefixes rev.2 update received"), + XCTestExpectation(description: "Hash prefixes rev.3 update received"), + ] + let c1 = await dataManager.publisher(for: .hashPrefixes(threatKind: .phishing)).dropFirst().sink { data in + expectations[data.revision - 1].fulfill() + } + // data for FilterSet should not be updated + let c2 = await dataManager.publisher(for: .filterSet(threatKind: .phishing)).dropFirst().sink { data in + XCTFail("Unexpected filter set update received: \(data)") + } + // synchronize Task threads to advance the Test Clock when the updated Task is sleeping, + // otherwise we‘ll eventually advance the clock before the sleep and get hung. + var sleepIndex = 0 + let sleepExpectations = [ + XCTestExpectation(description: "Will Sleep 1"), + XCTestExpectation(description: "Will Sleep 2"), + XCTestExpectation(description: "Will Sleep 3"), + ] + self.willSleep = { _ in + sleepExpectations[sleepIndex].fulfill() + sleepIndex += 1 + } + + // expect initial hashPrefixes update run instantly + updateTask = updateManager.startPeriodicUpdates() + await fulfillment(of: [expectations[0], sleepExpectations[0]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.2 update for hashPrefixes + await fulfillment(of: [expectations[1], sleepExpectations[1]], timeout: 1) + + // Advance the clock by 1 seconds + await self.clock.advance(by: .seconds(1)) + // expect to receive v.3 update for hashPrefixes + await fulfillment(of: [expectations[2], sleepExpectations[2]], timeout: 1) + + withExtendedLifetime((c1, c2)) {} + } + + func testWhenPeriodicUpdatesAreCancelled_noFurtherUpdatesReceived() async throws { + // Start periodic updates + self.updateIntervalProvider = { _ in 1 } + updateTask = updateManager.startPeriodicUpdates() + + // Wait for the initial update + try await withTimeout(1) { [self] in + for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + for await _ in await dataManager.publisher(for: .filterSet(threatKind: .phishing)).first(where: { $0.revision == 1 }).values {} + } + + // Cancel the update task + updateTask!.cancel() + + // Reset expectations for further updates + let c = await dataManager.$store.dropFirst().sink { data in + XCTFail("Unexpected data update received: \(data)") + } + + // Advance the clock to check for further updates + await self.clock.advance(by: .seconds(2)) + await clock.run() + await Task.megaYield(count: 10) + + // Verify that the data sets have not been updated further + let hashPrefixes = await dataManager.dataSet(for: .hashPrefixes(threatKind: .phishing)) + let filterSet = await dataManager.dataSet(for: .filterSet(threatKind: .phishing)) + XCTAssertEqual(hashPrefixes.revision, 1) // Expecting revision to remain 1 + XCTAssertEqual(filterSet.revision, 1) // Expecting revision to remain 1 + + withExtendedLifetime(c) {} + } + } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift b/Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift deleted file mode 100644 index 6f2f8a20a..000000000 --- a/Tests/MaliciousSiteProtectionTests/Mocks/BackgroundActivitySchedulerMock.swift +++ /dev/null @@ -1,34 +0,0 @@ -// -// BackgroundActivitySchedulerMock.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import MaliciousSiteProtection -actor MockBackgroundActivityScheduler: BackgroundActivityScheduling { - var startCalled = false - var stopCalled = false - var interval: TimeInterval = 1 - var identifier: String = "test" - - func start() { - startCalled = true - } - - func stop() { - stopCalled = true - } -} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index 1a580f5c8..4f2062edd 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -20,8 +20,8 @@ import Foundation @testable import MaliciousSiteProtection class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mockable { - var updateHashPrefixesWasCalled: Bool = false - var updateFilterSetsWasCalled: Bool = false + var updateHashPrefixesCalled: ((Int) -> Void)? + var updateFilterSetsCalled: ((Int) -> Void)? var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ 0: .init(insert: [ @@ -84,12 +84,12 @@ class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APIClient.Mo } } func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { - updateFilterSetsWasCalled = true + updateFilterSetsCalled?(revision) return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet { - updateHashPrefixesWasCalled = true + updateHashPrefixesCalled?(revision) return hashPrefixRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift index ea9a63a71..1a67ad329 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionDataManager.swift @@ -16,14 +16,21 @@ // limitations under the License. // +import Combine import Foundation @testable import MaliciousSiteProtection -class MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { - var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() +actor MockMaliciousSiteProtectionDataManager: MaliciousSiteProtection.DataManaging { - func dataSet(for key: DataKey) async -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { - store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) + @Published var store = [MaliciousSiteProtection.DataManager.StoredDataType: Any]() + func publisher(for key: DataKey) -> AnyPublisher where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + $store.map { $0[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } + .removeDuplicates() + .eraseToAnyPublisher() + } + + public func dataSet(for key: DataKey) -> DataKey.DataSet where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { + return store[key.dataType] as? DataKey.DataSet ?? .init(revision: 0, items: []) } func store(_ dataSet: DataKey.DataSet, for key: DataKey) async where DataKey: MaliciousSiteProtection.MaliciousSiteDataKey { diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift index 0beb44dd7..37f6c2962 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionEmbeddedDataProvider.swift @@ -17,35 +17,65 @@ // import Foundation + @testable import MaliciousSiteProtection -class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { +final class MockMaliciousSiteProtectionEmbeddedDataProvider: MaliciousSiteProtection.EmbeddedDataProviding { var embeddedRevision: Int = 65 var loadHashPrefixesCalled: Bool = false var loadFilterSetCalled: Bool = true - var hashPrefixes = Set(["aabb"]) - var filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) + var hashPrefixes: Set = [] { + didSet { + hashPrefixesData = try! JSONEncoder().encode(hashPrefixes) + } + } + var hashPrefixesData: Data! - func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { - embeddedRevision + var filterSet: Set = [] { + didSet { + filterSetData = try! JSONEncoder().encode(filterSet) + } } + var filterSetData: Data! - func url(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { - URL.empty + init() { + hashPrefixes = Set(["aabb"]) + filterSet = Set([Filter(hash: "dummyhash", regex: "dummyregex")]) } - func hash(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> String { - "" + func revision(for detectionKind: MaliciousSiteProtection.DataManager.StoredDataType) -> Int { + embeddedRevision } - func loadDataSet(for key: DataKey) -> DataKey.EmbeddedDataSet where DataKey: MaliciousSiteDataKey { - switch key.dataType { + func url(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> URL { + switch dataType { case .filterSet: self.loadFilterSetCalled = true - return Array(filterSet) as! DataKey.EmbeddedDataSet + return URL(string: "filterSet")! case .hashPrefixSet: self.loadHashPrefixesCalled = true - return Array(hashPrefixes) as! DataKey.EmbeddedDataSet + return URL(string: "hashPrefixSet")! + } + } + + func hash(for dataType: MaliciousSiteProtection.DataManager.StoredDataType) -> String { + let url = url(for: dataType) + let data = try! data(withContentsOf: url) + let sha = data.sha256 + return sha + } + + func data(withContentsOf url: URL) throws -> Data { + let data: Data + switch url.absoluteString { + case "filterSet": + self.loadFilterSetCalled = true + return filterSetData + case "hashPrefixSet": + self.loadHashPrefixesCalled = true + return hashPrefixesData + default: + fatalError("Unexpected url \(url.absoluteString)") } } diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift index 767f4c13a..b49eac588 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockPhishingDetectionUpdateManager.swift @@ -23,6 +23,7 @@ class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging var didUpdateFilterSet = false var didUpdateHashPrefixes = false + var startPeriodicUpdatesCalled = false var completionHandler: (() -> Void)? func updateData(for key: some MaliciousSiteProtection.MaliciousSiteDataKey) async { @@ -48,4 +49,8 @@ class MockPhishingDetectionUpdateManager: MaliciousSiteProtection.UpdateManaging } } + public func startPeriodicUpdates() -> Task { + startPeriodicUpdatesCalled = true + return Task {} + } } diff --git a/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift b/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift deleted file mode 100644 index 3d3ad4a01..000000000 --- a/Tests/MaliciousSiteProtectionTests/PhishingDetectionDataActivitiesTests.swift +++ /dev/null @@ -1,48 +0,0 @@ -// -// PhishingDetectionDataActivitiesTests.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import XCTest -@testable import MaliciousSiteProtection - -class PhishingDetectionDataActivitiesTests: XCTestCase { - var mockUpdateManager: MockPhishingDetectionUpdateManager! - var activities: PhishingDetectionDataActivities! - - override func setUp() { - super.setUp() - mockUpdateManager = MockPhishingDetectionUpdateManager() - activities = PhishingDetectionDataActivities(hashPrefixInterval: 1, filterSetInterval: 1, updateManager: mockUpdateManager) - } - - func testUpdateHashPrefixesAndFilterSetRuns() async { - let expectation = XCTestExpectation(description: "updateHashPrefixes and updateFilterSet completes") - - mockUpdateManager.completionHandler = { - expectation.fulfill() - } - - activities.start() - - await fulfillment(of: [expectation], timeout: 10.0) - - XCTAssertTrue(mockUpdateManager.didUpdateHashPrefixes) - XCTAssertTrue(mockUpdateManager.didUpdateFilterSet) - - } -}