From 50ed41cdbaabdc91dcc4ab65ddd9070909dcd960 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 21:45:35 +0600 Subject: [PATCH 1/6] refactor APIClient --- Package.swift | 1 + Sources/Common/Extensions/URLExtension.swift | 25 +-- .../API/APIClient.swift | 172 +++++++++--------- .../API/APIRequest.swift | 84 +++++++++ .../API/ChangeSetResponse.swift | 7 +- .../API/MatchResponse.swift | 4 + .../MaliciousSiteDetector.swift | 12 +- .../Services/UpdateManager.swift | 32 +++- Sources/Networking/README.md | 12 +- Sources/Networking/v2/APIRequestV2.swift | 48 +++-- Sources/Networking/v2/APIResponseV2.swift | 9 +- .../Extensions/Dictionary+URLQueryItem.swift | 35 ---- Sources/TestUtils/MockAPIService.swift | 12 +- ...aliciousSiteProtectionAPIClientTests.swift | 120 +++++++----- .../Mocks/MockMaliciousSiteDetector.swift | 38 ---- ...MockMaliciousSiteProtectionAPIClient.swift | 25 ++- .../v2/APIRequestV2Tests.swift | 41 ++--- .../NetworkingTests/v2/APIServiceTests.swift | 25 ++- 18 files changed, 387 insertions(+), 315 deletions(-) create mode 100644 Sources/MaliciousSiteProtection/API/APIRequest.swift delete mode 100644 Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift delete mode 100644 Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift diff --git a/Package.swift b/Package.swift index eadd9b6d1..c074ced7e 100644 --- a/Package.swift +++ b/Package.swift @@ -651,6 +651,7 @@ let package = Package( .testTarget( name: "MaliciousSiteProtectionTests", dependencies: [ + "TestUtils", "MaliciousSiteProtection", ], resources: [ diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift index d19751148..78cc2881f 100644 --- a/Sources/Common/Extensions/URLExtension.swift +++ b/Sources/Common/Extensions/URLExtension.swift @@ -354,22 +354,24 @@ extension URL { // MARK: - Parameters + @_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order. public func appendingParameters(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL where QueryParams.Element == (key: String, value: String) { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result + } - return parameters.reduce(self) { partialResult, parameter in - partialResult.appendingParameter( - name: parameter.key, - value: parameter.value, - allowedReservedCharacters: allowedReservedCharacters - ) - } + public func appendingParameters(_ parameters: KeyValuePairs, allowedReservedCharacters: CharacterSet? = nil) -> URL { + let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in + URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) + }) + return result } public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL { - let queryItem = URLQueryItem(percentEncodingName: name, - value: value, - withAllowedCharacters: allowedReservedCharacters) + let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters) return self.appending(percentEncodedQueryItem: queryItem) } @@ -383,8 +385,9 @@ extension URL { var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]() existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems) components.percentEncodedQueryItems = existingPercentEncodedQueryItems + let result = components.url ?? self - return components.url ?? self + return result } public func getQueryItems() -> [URLQueryItem]? { diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 50ddfbd6a..678171839 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -18,129 +18,119 @@ import Common import Foundation -import os import Networking public protocol APIClientProtocol { - func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse - func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse - func getMatches(hashPrefix: String) async -> [Match] + func load(_ requestConfig: Request) async throws -> Request.ResponseType } -public protocol URLSessionProtocol { - func data(for request: URLRequest) async throws -> (Data, URLResponse) +public extension APIClientProtocol where Self == APIClient { + static var production: APIClientProtocol { APIClient(environment: .production) } + static var staging: APIClientProtocol { APIClient(environment: .staging) } } -extension URLSession: URLSessionProtocol {} - -extension URLSessionProtocol { - public static var defaultSession: URLSessionProtocol { - return URLSession.shared - } +public protocol APIClientEnvironment { + func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 + func url(for request: APIClient.Request) -> URL } -public struct APIClient: APIClientProtocol { +public extension APIClient { + enum DefaultEnvironment: APIClientEnvironment { - public enum Environment { case production case staging - } + case dev - enum Constants { - static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")! - static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")! - enum APIPath: String { - case filterSet - case hashPrefix - case matches + var endpoint: URL { + switch self { + case .production: URL(string: "https://duckduckgo.com/api/protection/")! + case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")! + case .dev: URL(string: "https://4842-20-93-28-24.ngrok-free.app/api/protection/")! + } } - } - private let endpointURL: URL - private let session: URLSessionProtocol! - private var headers: [String: String]? = [:] + var defaultHeaders: APIRequestV2.HeadersV2 { + .init(userAgent: APIRequest.Headers.userAgent) + } - var filterSetURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue) - } + enum APIPath { + static let filterSet = "filterSet" + static let hashPrefix = "hashPrefix" + static let matches = "matches" + } - var hashPrefixURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue) - } + enum QueryParameter { + static let category = "category" + static let revision = "revision" + static let hashPrefix = "hashPrefix" + } - var matchesURL: URL { - endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue) - } + public func url(for request: APIClient.Request) -> URL { + switch request { + case .hashPrefixSet(let configuration): + endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .filterSet(let configuration): + endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([ + QueryParameter.category: configuration.threatKind.rawValue, + QueryParameter.revision: (configuration.revision ?? 0).description, + ]) + case .matches(let configuration): + endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix) + } + } - public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) { - switch environment { - case .production: - endpointURL = Constants.productionEndpoint - case .staging: - endpointURL = Constants.stagingEndpoint + public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 { + defaultHeaders } - self.session = session } - public func getFilterSet(revision: Int) async -> FiltersChangeSetResponse { - guard let url = createURL(for: .filterSet, revision: revision) else { - logDebug("🔸 Invalid filterSet revision URL: \(revision)") - return FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: FiltersChangeSetResponse.self) ?? FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) +} + +public struct APIClient: APIClientProtocol { + + let environment: APIClientEnvironment + private let service: APIService + + public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) { + self.init(environment: environment as APIClientEnvironment, service: service) } - public func getHashPrefixes(revision: Int) async -> HashPrefixesChangeSetResponse { - guard let url = createURL(for: .hashPrefix, revision: revision) else { - logDebug("🔸 Invalid hashPrefix revision URL: \(revision)") - return HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) - } - return await fetch(url: url, responseType: HashPrefixesChangeSetResponse.self) ?? HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false) + public init(environment: APIClientEnvironment, service: APIService) { + self.environment = environment + self.service = service } - public func getMatches(hashPrefix: String) async -> [Match] { - let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)] - guard let url = createURL(for: .matches, queryItems: queryItems) else { - logDebug("🔸 Invalid matches URL: \(hashPrefix)") - return [] - } - return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? [] + public func load(_ requestConfig: Request) async throws -> Request.ResponseType { + let requestType = requestConfig.requestType + let headers = environment.headers(for: requestType) + let url = environment.url(for: requestType) + + let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) + let response = try await service.fetch(request: apiRequest) + let result: Request.ResponseType = try response.decodeBody() + + return result } -} -// MARK: Private Methods -extension APIClient { +} - private func logDebug(_ message: String) { - Logger.api.debug("\(message)") +// MARK: - Convenience +extension APIClientProtocol { + public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet { + let result = try await load(.filterSet(threatKind: threatKind, revision: revision)) + return result } - private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? { - // Start with the base URL and append the path component - var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true) - var items = queryItems ?? [] - if let revision = revision, revision > 0 { - items.append(URLQueryItem(name: "revision", value: String(revision))) - } - urlComponents?.queryItems = items.isEmpty ? nil : items - return urlComponents?.url + public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet { + let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision)) + return result } - private func fetch(url: URL, responseType: T.Type) async -> T? { - var request = URLRequest(url: url) - request.httpMethod = "GET" - request.allHTTPHeaderFields = headers - - do { - let (data, _) = try await session.data(for: request) - if let response = try? JSONDecoder().decode(responseType, from: data) { - return response - } else { - logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)") - } - } catch { - logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)") - } - return nil + public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches { + let result = try await load(.matches(hashPrefix: hashPrefix)) + return result } } diff --git a/Sources/MaliciousSiteProtection/API/APIRequest.swift b/Sources/MaliciousSiteProtection/API/APIRequest.swift new file mode 100644 index 000000000..af18fb2f2 --- /dev/null +++ b/Sources/MaliciousSiteProtection/API/APIRequest.swift @@ -0,0 +1,84 @@ +// +// APIRequest.swift +// +// Copyright © 2024 DuckDuckGo. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +import Foundation + +public protocol APIRequestProtocol { + associatedtype ResponseType: Decodable + var requestType: APIClient.Request { get } +} + +public extension APIClient { + enum Request { + case hashPrefixSet(HashPrefixes) + case filterSet(FilterSet) + case matches(Matches) + } +} +public extension APIClient.Request { + struct HashPrefixes: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet + + public let threatKind: ThreatKind + public let revision: Int? + + public var requestType: APIClient.Request { + .hashPrefixSet(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes { + static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIClient.Request { + struct FilterSet: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.FiltersChangeSet + + public let threatKind: ThreatKind + public let revision: Int? + + public var requestType: APIClient.Request { + .filterSet(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.FilterSet { + static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self { + .init(threatKind: threatKind, revision: revision) + } +} + +public extension APIClient.Request { + struct Matches: APIRequestProtocol { + public typealias ResponseType = APIClient.Response.Matches + + public let hashPrefix: String + + public var requestType: APIClient.Request { + .matches(self) + } + } +} +extension APIRequestProtocol where Self == APIClient.Request.Matches { + static func matches(hashPrefix: String) -> Self { + .init(hashPrefix: hashPrefix) + } +} diff --git a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift index 732411895..7988c2a32 100644 --- a/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift +++ b/Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift @@ -34,7 +34,10 @@ extension APIClient { } } - public typealias FiltersChangeSetResponse = ChangeSetResponse - public typealias HashPrefixesChangeSetResponse = ChangeSetResponse + public enum Response { + public typealias FiltersChangeSet = ChangeSetResponse + public typealias HashPrefixesChangeSet = ChangeSetResponse + public typealias Matches = MatchResponse + } } diff --git a/Sources/MaliciousSiteProtection/API/MatchResponse.swift b/Sources/MaliciousSiteProtection/API/MatchResponse.swift index aaa48b388..2cb6df962 100644 --- a/Sources/MaliciousSiteProtection/API/MatchResponse.swift +++ b/Sources/MaliciousSiteProtection/API/MatchResponse.swift @@ -20,6 +20,10 @@ extension APIClient { public struct MatchResponse: Codable, Equatable { public var matches: [Match] + + public init(matches: [Match]) { + self.matches = matches + } } } diff --git a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift index 9ac4c01c2..592e7e852 100644 --- a/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift +++ b/Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift @@ -41,10 +41,6 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { self.eventMapping = eventMapping } - private func getMatches(hashPrefix: String) async -> Set { - return Set(await apiClient.getMatches(hashPrefix: hashPrefix)) - } - private func inFilterSet(hash: String) -> Set { return Set(dataManager.filterSet.filter { $0.hash == hash }) } @@ -65,7 +61,13 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting { } private func fetchMatches(hashPrefix: String) async -> [Match] { - return await apiClient.getMatches(hashPrefix: hashPrefix) + do { + let response = try await apiClient.matches(forHashPrefix: hashPrefix) + return response.matches + } catch { + Logger.api.error("Failed to fetch matches for hash prefix: \(hashPrefix): \(error.localizedDescription)") + return [] + } } private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool { diff --git a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift index 053acd230..af9f60e7d 100644 --- a/Sources/MaliciousSiteProtection/Services/UpdateManager.swift +++ b/Sources/MaliciousSiteProtection/Services/UpdateManager.swift @@ -54,30 +54,42 @@ public struct UpdateManager: UpdateManaging { } public func updateFilterSet() async { - let response = await apiClient.getFilterSet(revision: dataManager.currentRevision) + let changeSet: APIClient.Response.FiltersChangeSet + do { + changeSet = try await apiClient.filtersChangeSet(for: .phishing, revision: dataManager.currentRevision) + } catch { + Logger.updateManager.error("error fetching filter set: \(error)") + return + } updateSet( currentSet: dataManager.filterSet, - insert: response.insert, - delete: response.delete, - replace: response.replace + insert: changeSet.insert, + delete: changeSet.delete, + replace: changeSet.replace ) { newSet in self.dataManager.saveFilterSet(set: newSet) } - dataManager.saveRevision(response.revision) + dataManager.saveRevision(changeSet.revision) Logger.updateManager.debug("filterSet updated to revision \(self.dataManager.currentRevision)") } public func updateHashPrefixes() async { - let response = await apiClient.getHashPrefixes(revision: dataManager.currentRevision) + let changeSet: APIClient.Response.HashPrefixesChangeSet + do { + changeSet = try await apiClient.hashPrefixesChangeSet(for: .phishing, revision: dataManager.currentRevision) + } catch { + Logger.updateManager.error("error fetching hash prefixes: \(error)") + return + } updateSet( currentSet: dataManager.hashPrefixes, - insert: response.insert, - delete: response.delete, - replace: response.replace + insert: changeSet.insert, + delete: changeSet.delete, + replace: changeSet.replace ) { newSet in self.dataManager.saveHashPrefixes(set: newSet) } - dataManager.saveRevision(response.revision) + dataManager.saveRevision(changeSet.revision) Logger.updateManager.debug("hashPrefixes updated to revision \(self.dataManager.currentRevision)") } } diff --git a/Sources/Networking/README.md b/Sources/Networking/README.md index 83a2c5ce3..751ee63d8 100644 --- a/Sources/Networking/README.md +++ b/Sources/Networking/README.md @@ -19,7 +19,7 @@ let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: [.allowHTTPNotModified, .requireETagHeader, .requireUserAgent], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService(urlSession: URLSession.shared) ``` @@ -55,12 +55,12 @@ The `MockPIService` implementing `APIService` can be found in `BSK/TestUtils` ``` let apiResponse = (Data(), HTTPURLResponse(url: HTTPURLResponse.testUrl, - statusCode: 200, - httpVersion: nil, - headerFields: nil)!) -let mockedAPIService = MockAPIService(decodableResponse: Result.failure(SomeError.testError), apiResponse: Result.success(apiResponse) ) + statusCode: 200, + httpVersion: nil, + headerFields: nil)!) +let mockedAPIService = MockAPIService(apiResponse: Result.success(apiResponse)) ``` ## v1 (Legacy) -Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. \ No newline at end of file +Not to be used. All V1 public functions have been deprecated and maintained only for backward compatibility. diff --git a/Sources/Networking/v2/APIRequestV2.swift b/Sources/Networking/v2/APIRequestV2.swift index 07434de67..a61604861 100644 --- a/Sources/Networking/v2/APIRequestV2.swift +++ b/Sources/Networking/v2/APIRequestV2.swift @@ -16,12 +16,11 @@ // limitations under the License. // +import Common import Foundation public struct APIRequestV2: CustomDebugStringConvertible { - public typealias QueryItems = [String: String] - let timeoutInterval: TimeInterval let responseConstraints: [APIResponseConstraints]? public let urlRequest: URLRequest @@ -37,25 +36,25 @@ public struct APIRequestV2: CustomDebugStringConvertible { /// - cachePolicy: The request cache policy, default is `.useProtocolCachePolicy` /// - responseRequirements: The response requirements /// - allowedQueryReservedCharacters: The characters in this character set will not be URL encoded in the query parameters - public init?(url: URL, - method: HTTPRequestMethod = .get, - queryItems: QueryItems? = nil, - headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), - body: Data? = nil, - timeoutInterval: TimeInterval = 60.0, - cachePolicy: URLRequest.CachePolicy? = nil, - responseConstraints: [APIResponseConstraints]? = nil, - allowedQueryReservedCharacters: CharacterSet? = nil) { + public init( + url: URL, + method: HTTPRequestMethod = .get, + queryItems: QueryParams?, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) where QueryParams.Element == (key: String, value: String) { + self.timeoutInterval = timeoutInterval self.responseConstraints = responseConstraints - // Generate URL request - guard var urlComps = URLComponents(url: url, resolvingAgainstBaseURL: true) else { - return nil - } - urlComps.queryItems = queryItems?.toURLQueryItems(allowedReservedCharacters: allowedQueryReservedCharacters) - guard let finalURL = urlComps.url else { - return nil + let finalURL = if let queryItems { + url.appendingParameters(queryItems, allowedReservedCharacters: allowedQueryReservedCharacters) + } else { + url } var request = URLRequest(url: finalURL, timeoutInterval: timeoutInterval) request.allHTTPHeaderFields = headers?.httpHeaders @@ -67,6 +66,19 @@ public struct APIRequestV2: CustomDebugStringConvertible { self.urlRequest = request } + public init( + url: URL, + method: HTTPRequestMethod = .get, + headers: APIRequestV2.HeadersV2? = APIRequestV2.HeadersV2(), + body: Data? = nil, + timeoutInterval: TimeInterval = 60.0, + cachePolicy: URLRequest.CachePolicy? = nil, + responseConstraints: [APIResponseConstraints]? = nil, + allowedQueryReservedCharacters: CharacterSet? = nil + ) { + self.init(url: url, method: method, queryItems: [String: String]?.none, headers: headers, body: body, timeoutInterval: timeoutInterval, cachePolicy: cachePolicy, responseConstraints: responseConstraints, allowedQueryReservedCharacters: allowedQueryReservedCharacters) + } + public var debugDescription: String { """ APIRequestV2: diff --git a/Sources/Networking/v2/APIResponseV2.swift b/Sources/Networking/v2/APIResponseV2.swift index 1b178fd93..8987e377b 100644 --- a/Sources/Networking/v2/APIResponseV2.swift +++ b/Sources/Networking/v2/APIResponseV2.swift @@ -20,8 +20,13 @@ import Foundation import os.log public struct APIResponseV2 { - let data: Data? - let httpResponse: HTTPURLResponse + public let data: Data? + public let httpResponse: HTTPURLResponse + + public init(data: Data?, httpResponse: HTTPURLResponse) { + self.data = data + self.httpResponse = httpResponse + } } public extension APIResponseV2 { diff --git a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift b/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift deleted file mode 100644 index 81a4648d6..000000000 --- a/Sources/Networking/v2/Extensions/Dictionary+URLQueryItem.swift +++ /dev/null @@ -1,35 +0,0 @@ -// -// Dictionary+URLQueryItem.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import Common - -extension Dictionary where Key == String, Value == String { - - func toURLQueryItems(allowedReservedCharacters: CharacterSet? = nil) -> [URLQueryItem] { - return self.map { - if let allowedReservedCharacters { - URLQueryItem(percentEncodingName: $0.key, - value: $0.value, - withAllowedCharacters: allowedReservedCharacters) - } else { - URLQueryItem(name: $0.key, value: $0.value) - } - } - } -} diff --git a/Sources/TestUtils/MockAPIService.swift b/Sources/TestUtils/MockAPIService.swift index be4bf47a2..f4d35b4b6 100644 --- a/Sources/TestUtils/MockAPIService.swift +++ b/Sources/TestUtils/MockAPIService.swift @@ -19,12 +19,16 @@ import Foundation import Networking -public struct MockAPIService: APIService { +public class MockAPIService: APIService { - public var apiResponse: Result + public var requestHandler: ((APIRequestV2) -> Result)! - public func fetch(request: Networking.APIRequestV2) async throws -> APIResponseV2 { - switch apiResponse { + public init(requestHandler: ((APIRequestV2) -> Result)? = nil) { + self.requestHandler = requestHandler + } + + public func fetch(request: APIRequestV2) async throws -> APIResponseV2 { + switch requestHandler!(request) { case .success(let result): return result case .failure(let error): diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index 52f74b97c..e1fd6db4f 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -16,110 +16,130 @@ // limitations under the License. // import Foundation +import Networking +import TestUtils import XCTest + @testable import MaliciousSiteProtection final class MaliciousSiteProtectionAPIClientTests: XCTestCase { - var mockSession: MockURLSession! + var mockService: MockAPIService! var client: MaliciousSiteProtection.APIClient! override func setUp() { super.setUp() - mockSession = MockURLSession() - client = .init(environment: .staging, session: mockSession) + mockService = MockAPIService() + client = .init(environment: .staging, service: mockService) } override func tearDown() { - mockSession = nil + mockService = nil client = nil super.tearDown() } - func testGetFilterSetSuccess() async { + func testWhenPhishingFilterSetRequestedAndSucceeds_ChangeSetIsReturned() async throws { // Given let insertFilter = Filter(hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", regex: ".") let deleteFilter = Filter(hash: "6a929cd0b3ba4677eaedf1b2bdaf3ff89281cca94f688c83103bc9a676aea46d", regex: "(?i)^https?\\:\\/\\/[\\w\\-\\.]+(?:\\:(?:80|443))?") - let expectedResponse = APIClient.FiltersChangeSetResponse(insert: [insertFilter], delete: [deleteFilter], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.filterSetURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.FiltersChangeSet(insert: [insertFilter], delete: [deleteFilter], revision: 666, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .filterSet(.init(threatKind: .phishing, revision: 666)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getFilterSet(revision: 1) + let response = try await client.filtersChangeSet(for: .phishing, revision: 666) // Then XCTAssertEqual(response, expectedResponse) } - func testGetHashPrefixesSuccess() async { + func testWhenHashPrefixesRequestedAndSucceeds_ChangeSetIsReturned() async throws { // Given - let expectedResponse = APIClient.HashPrefixesChangeSetResponse(insert: ["abc"], delete: ["def"], revision: 1, replace: false) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.hashPrefixURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.HashPrefixesChangeSet(insert: ["abc"], delete: ["def"], revision: 1, replace: false) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .hashPrefixSet(.init(threatKind: .phishing, revision: 1)))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getHashPrefixes(revision: 1) + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: 1) // Then XCTAssertEqual(response, expectedResponse) } - func testGetMatchesSuccess() async { + func testWhenMatchesRequestedAndSucceeds_MatchesAreReturned() async throws { // Given - let expectedResponse = APIClient.MatchResponse(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) - mockSession.data = try? JSONEncoder().encode(expectedResponse) - mockSession.response = HTTPURLResponse(url: client.matchesURL, statusCode: 200, httpVersion: nil, headerFields: nil) + let expectedResponse = APIClient.Response.Matches(matches: [Match(hostname: "example.com", url: "https://example.com/test", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil)]) + mockService.requestHandler = { [unowned self] in + XCTAssertEqual($0.urlRequest.url, client.environment.url(for: .matches(.init(hashPrefix: "abc")))) + let data = try? JSONEncoder().encode(expectedResponse) + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + return .success(.init(data: data, httpResponse: response)) + } // When - let response = await client.getMatches(hashPrefix: "abc") + let response = try await client.matches(forHashPrefix: "abc") // Then - XCTAssertEqual(response, expectedResponse.matches) + XCTAssertEqual(response.matches, expectedResponse.matches) } - func testGetFilterSetInvalidURL() async { + func testWhenHashPrefixesRequestFails_ErrorThrown() async throws { // Given let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getFilterSet(revision: invalidRevision) - - // Then - XCTAssertEqual(response, .init(insert: [], delete: [], revision: invalidRevision, replace: false)) + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } } - func testGetHashPrefixesInvalidURL() async { + func testWhenFilterSetRequestFails_ErrorThrown() async throws { // Given let invalidRevision = -1 + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getHashPrefixes(revision: invalidRevision) - - // Then - XCTAssertEqual(response, .init(insert: [], delete: [], revision: invalidRevision, replace: false)) + do { + let response = try await client.hashPrefixesChangeSet(for: .phishing, revision: invalidRevision) + XCTFail("Unexpected \(response) expected throw") + } catch { + } } - func testGetMatchesInvalidURL() async { + + func testWhenMatchesRequestFails_ErrorThrown() async throws { // Given let invalidHashPrefix = "" + mockService.requestHandler = { + // Simulate a failure or invalid request + let response = HTTPURLResponse(url: $0.urlRequest.url!, statusCode: 400, httpVersion: nil, headerFields: nil)! + return .success(.init(data: nil, httpResponse: response)) + } - // When - let response = await client.getMatches(hashPrefix: invalidHashPrefix) - - // Then - XCTAssertTrue(response.isEmpty) - } -} - -class MockURLSession: URLSessionProtocol { - var data: Data? - var response: URLResponse? - var error: Error? - - func data(for request: URLRequest) async throws -> (Data, URLResponse) { - if let error = error { - throw error + do { + let response = try await client.matches(forHashPrefix: invalidHashPrefix) + XCTFail("Unexpected \(response) expected throw") + } catch { } - return (data ?? Data(), response ?? URLResponse()) } + } + diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift deleted file mode 100644 index 0d54cd459..000000000 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteDetector.swift +++ /dev/null @@ -1,38 +0,0 @@ -// -// MockMaliciousSiteDetector.swift -// -// Copyright © 2024 DuckDuckGo. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -import Foundation -import MaliciousSiteProtection - -public class MockMaliciousSiteDetector: MaliciousSiteDetecting { - private var mockClient: MaliciousSiteProtection.APIClientProtocol - public var didCallIsMalicious: Bool = false - - init() { - self.mockClient = MockMaliciousSiteProtectionAPIClient() - } - - public func getMatches(hashPrefix: String) async -> Set { - let matches = await mockClient.getMatches(hashPrefix: hashPrefix) - return Set(matches) - } - - public func evaluate(_ url: URL) async -> ThreatKind? { - return url.absoluteString.contains("malicious") ? .phishing : nil - } -} diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index ad2a31fe9..d47552491 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -23,7 +23,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl public var updateHashPrefixesWasCalled: Bool = false public var updateFilterSetsWasCalled: Bool = false - private var filterRevisions: [Int: APIClient.FiltersChangeSetResponse] = [ + private var filterRevisions: [Int: APIClient.Response.FiltersChangeSet] = [ 0: .init(insert: [ Filter(hash: "testhash1", regex: ".*example.*"), Filter(hash: "testhash2", regex: ".*test.*") @@ -45,7 +45,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - private var hashPrefixRevisions: [Int: APIClient.HashPrefixesChangeSetResponse] = [ + private var hashPrefixRevisions: [Int: APIClient.Response.HashPrefixesChangeSet] = [ 0: .init(insert: [ "aa00bb11", "bb00cc11", @@ -65,20 +65,31 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - public func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse { + public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request : APIRequestProtocol { + switch requestConfig.requestType { + case .hashPrefixSet(let configuration): + return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + case .filterSet(let configuration): + return _filtersChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType + case .matches(let configuration): + return _matches(forHashPrefix: configuration.hashPrefix) as! Request.ResponseType + } + } + func _filtersChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.FiltersChangeSet { updateFilterSetsWasCalled = true return filterRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } - public func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse { + func _hashPrefixesChangeSet(for threatKind: MaliciousSiteProtection.ThreatKind, revision: Int) -> MaliciousSiteProtection.APIClient.Response.HashPrefixesChangeSet { updateHashPrefixesWasCalled = true return hashPrefixRevisions[revision] ?? .init(insert: [], delete: [], revision: revision, replace: false) } - public func getMatches(hashPrefix: String) async -> [Match] { - return [ + func _matches(forHashPrefix hashPrefix: String) -> APIClient.Response.Matches { + .init(matches: [ Match(hostname: "example.com", url: "https://example.com/mal", regex: ".", hash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", category: nil), Match(hostname: "test.com", url: "https://test.com/mal", regex: ".*test.*", hash: "aa00bb11aa00cc11bb00cc11", category: nil) - ] + ]) } + } diff --git a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift index 59eeadebb..4ec1b8b59 100644 --- a/Tests/NetworkingTests/v2/APIRequestV2Tests.swift +++ b/Tests/NetworkingTests/v2/APIRequestV2Tests.swift @@ -41,21 +41,18 @@ final class APIRequestV2Tests: XCTestCase { cachePolicy: cachePolicy, responseConstraints: constraints) - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.url?.host(), url.host()) XCTAssertEqual(urlRequest.httpMethod, method.rawValue) let urlComponents = URLComponents(string: urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertEqual(urlRequest.allHTTPHeaderFields, headers.httpHeaders) XCTAssertEqual(urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.timeoutInterval, timeoutInterval) XCTAssertEqual(urlRequest.cachePolicy, cachePolicy) - XCTAssertEqual(apiRequest?.responseConstraints, constraints) + XCTAssertEqual(apiRequest.responseConstraints, constraints) } func testURLRequestGeneration() { @@ -75,16 +72,16 @@ final class APIRequestV2Tests: XCTestCase { timeoutInterval: timeoutInterval, cachePolicy: cachePolicy) - let urlComponents = URLComponents(string: apiRequest!.urlRequest.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(queryItems.toURLQueryItems())) + let urlComponents = URLComponents(string: apiRequest.urlRequest.url!.absoluteString)! + XCTAssertTrue(urlComponents.queryItems!.contains(URLQueryItem(name: "key", value: "value"))) XCTAssertNotNil(apiRequest) - XCTAssertEqual(apiRequest?.urlRequest.url?.absoluteString, "https://www.example.com?key=value") - XCTAssertEqual(apiRequest?.urlRequest.httpMethod, method.rawValue) - XCTAssertEqual(apiRequest?.urlRequest.allHTTPHeaderFields, headers.httpHeaders) - XCTAssertEqual(apiRequest?.urlRequest.httpBody, body) - XCTAssertEqual(apiRequest?.urlRequest.timeoutInterval, timeoutInterval) - XCTAssertEqual(apiRequest?.urlRequest.cachePolicy, cachePolicy) + XCTAssertEqual(apiRequest.urlRequest.url?.absoluteString, "https://www.example.com?key=value") + XCTAssertEqual(apiRequest.urlRequest.httpMethod, method.rawValue) + XCTAssertEqual(apiRequest.urlRequest.allHTTPHeaderFields, headers.httpHeaders) + XCTAssertEqual(apiRequest.urlRequest.httpBody, body) + XCTAssertEqual(apiRequest.urlRequest.timeoutInterval, timeoutInterval) + XCTAssertEqual(apiRequest.urlRequest.cachePolicy, cachePolicy) } func testDefaultValues() { @@ -92,16 +89,13 @@ final class APIRequestV2Tests: XCTestCase { let apiRequest = APIRequestV2(url: url) let headers = APIRequestV2.HeadersV2() - guard let urlRequest = apiRequest?.urlRequest else { - XCTFail("Nil URLRequest") - return - } + let urlRequest = apiRequest.urlRequest XCTAssertEqual(urlRequest.httpMethod, HTTPRequestMethod.get.rawValue) XCTAssertEqual(urlRequest.timeoutInterval, 60.0) XCTAssertEqual(headers.httpHeaders, urlRequest.allHTTPHeaderFields) XCTAssertNil(urlRequest.httpBody) XCTAssertEqual(urlRequest.cachePolicy.rawValue, 0) - XCTAssertNil(apiRequest?.responseConstraints) + XCTAssertNil(apiRequest.responseConstraints) } func testAllowedQueryReservedCharacters() { @@ -112,9 +106,10 @@ final class APIRequestV2Tests: XCTestCase { queryItems: queryItems, allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) - let urlString = apiRequest!.urlRequest.url!.absoluteString - XCTAssertTrue(urlString == "https://www.example.com?k%2523e,y=val%2523ue") + let urlString = apiRequest.urlRequest.url!.absoluteString + XCTAssertEqual(urlString, "https://www.example.com?k%23e,y=val%23ue") + let urlComponents = URLComponents(string: urlString)! - XCTAssertTrue(urlComponents.queryItems?.count == 1) + XCTAssertEqual(urlComponents.queryItems?.count, 1) } } diff --git a/Tests/NetworkingTests/v2/APIServiceTests.swift b/Tests/NetworkingTests/v2/APIServiceTests.swift index 9cae44323..730d6afbb 100644 --- a/Tests/NetworkingTests/v2/APIServiceTests.swift +++ b/Tests/NetworkingTests/v2/APIServiceTests.swift @@ -41,7 +41,7 @@ final class APIServiceTests: XCTestCase { cachePolicy: .reloadIgnoringLocalAndRemoteCacheData, responseConstraints: [APIResponseConstraints.allowHTTPNotModified, APIResponseConstraints.requireETagHeader], - allowedQueryReservedCharacters: CharacterSet(charactersIn: ","))! + allowedQueryReservedCharacters: CharacterSet(charactersIn: ",")) let apiService = DefaultAPIService() let response = try await apiService.fetch(request: request) let responseHTML: String = try response.decodeBody() @@ -50,7 +50,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallJSON() async throws { // func testRealCallJSON() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -63,7 +63,7 @@ final class APIServiceTests: XCTestCase { func disabled_testRealCallString() async throws { // func testRealCallString() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) let apiService = DefaultAPIService() let result = try await apiService.fetch(request: request) @@ -75,17 +75,16 @@ final class APIServiceTests: XCTestCase { "qName2": "qValue2"] MockURLProtocol.requestHandler = { request in let urlComponents = URLComponents(string: request.url!.absoluteString)! - XCTAssertTrue(urlComponents.queryItems!.contains(qItems.toURLQueryItems())) + XCTAssertTrue(urlComponents.queryItems!.contains(qItems.map { URLQueryItem(name: $0.key, value: $0.value) })) return (HTTPURLResponse.ok, nil) } - let request = APIRequestV2(url: HTTPURLResponse.testUrl, - queryItems: qItems)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, queryItems: qItems) let apiService = DefaultAPIService(urlSession: mockURLSession) _ = try await apiService.fetch(request: request) } func testURLRequestError() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) enum TestError: Error { case anError @@ -111,7 +110,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementAllowHTTPNotModifiedSuccess() async throws { let requirements = [APIResponseConstraints.allowHTTPNotModified ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -122,7 +121,7 @@ final class APIServiceTests: XCTestCase { } func testResponseRequirementAllowHTTPNotModifiedFailure() async throws { - let request = APIRequestV2(url: HTTPURLResponse.testUrl)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.notModified, Data()) } @@ -147,7 +146,7 @@ final class APIServiceTests: XCTestCase { let requirements: [APIResponseConstraints] = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } // HTTPURLResponse.ok contains etag let apiService = DefaultAPIService(urlSession: mockURLSession) @@ -158,7 +157,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireETagHeaderFailure() async throws { let requirements = [ APIResponseConstraints.requireETagHeader ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okNoEtag, nil) } @@ -181,7 +180,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentSuccess() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.okUserAgent, nil) @@ -194,7 +193,7 @@ final class APIServiceTests: XCTestCase { func testResponseRequirementRequireUserAgentFailure() async throws { let requirements = [ APIResponseConstraints.requireUserAgent ] - let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements)! + let request = APIRequestV2(url: HTTPURLResponse.testUrl, responseConstraints: requirements) MockURLProtocol.requestHandler = { _ in ( HTTPURLResponse.ok, nil) } From 9a32e36c267cf135ae5a48636dadbfeec23da179 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 21:55:58 +0600 Subject: [PATCH 2/6] fix linter issues --- .../MaliciousSiteProtectionAPIClientTests.swift | 2 -- .../Mocks/MockMaliciousSiteProtectionAPIClient.swift | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift index e1fd6db4f..d32d264ea 100644 --- a/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift +++ b/Tests/MaliciousSiteProtectionTests/MaliciousSiteProtectionAPIClientTests.swift @@ -124,7 +124,6 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { } } - func testWhenMatchesRequestFails_ErrorThrown() async throws { // Given let invalidHashPrefix = "" @@ -142,4 +141,3 @@ final class MaliciousSiteProtectionAPIClientTests: XCTestCase { } } - diff --git a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift index d47552491..24d1c203b 100644 --- a/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift +++ b/Tests/MaliciousSiteProtectionTests/Mocks/MockMaliciousSiteProtectionAPIClient.swift @@ -65,7 +65,7 @@ public class MockMaliciousSiteProtectionAPIClient: MaliciousSiteProtection.APICl ], revision: 4, replace: false) ] - public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request : APIRequestProtocol { + public func load(_ requestConfig: Request) async throws -> Request.ResponseType where Request: APIRequestProtocol { switch requestConfig.requestType { case .hashPrefixSet(let configuration): return _hashPrefixesChangeSet(for: configuration.threatKind, revision: configuration.revision ?? 0) as! Request.ResponseType From b2bdcc7fedf176fbf167d17bd76e25e488e71b86 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Mon, 25 Nov 2024 22:02:49 +0600 Subject: [PATCH 3/6] fix failing test --- Sources/Common/Extensions/URLExtension.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/Common/Extensions/URLExtension.swift b/Sources/Common/Extensions/URLExtension.swift index 78cc2881f..ce68773d5 100644 --- a/Sources/Common/Extensions/URLExtension.swift +++ b/Sources/Common/Extensions/URLExtension.swift @@ -380,7 +380,8 @@ extension URL { } public func appending(percentEncodedQueryItems: [URLQueryItem]) -> URL { - guard var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } + guard !percentEncodedQueryItems.isEmpty, + var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self } var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]() existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems) From f82742b77a42adc602ed0df00e90a041c3a10264 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 19:28:26 +0600 Subject: [PATCH 4/6] add Timeout support to APIClient env; set "matches" timeout to 1 --- Sources/MaliciousSiteProtection/API/APIClient.swift | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 678171839..e9ab09548 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -32,6 +32,7 @@ public extension APIClientProtocol where Self == APIClient { public protocol APIClientEnvironment { func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 func url(for request: APIClient.Request) -> URL + func timeout(for request: APIClient.Request) -> TimeInterval } public extension APIClient { @@ -85,6 +86,15 @@ public extension APIClient { public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 { defaultHeaders } + + public func timeout(for request: APIClient.Request) -> URL { + switch request { + case .hashPrefixSet, .filterSet: 60 + // This could block navigation so we should favour navigation loading if the backend is degraded. + // On Android we're looking at a maximum 1 second timeout for this request. + case .matches: 1 + } + } } } @@ -107,8 +117,9 @@ public struct APIClient: APIClientProtocol { let requestType = requestConfig.requestType let headers = environment.headers(for: requestType) let url = environment.url(for: requestType) + let timeout = environment.timeout(for: requestType) - let apiRequest = APIRequestV2(url: url, method: .get, headers: headers) + let apiRequest = APIRequestV2(url: url, method: .get, headers: headers, timeoutInterval: timeout) let response = try await service.fetch(request: apiRequest) let result: Request.ResponseType = try response.decodeBody() From ec8699139ec8109902035416d0278196bd179d9e Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 19:30:02 +0600 Subject: [PATCH 5/6] drop `dev` env --- Sources/MaliciousSiteProtection/API/APIClient.swift | 2 -- 1 file changed, 2 deletions(-) diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index e9ab09548..76e9dfc26 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -40,13 +40,11 @@ public extension APIClient { case production case staging - case dev var endpoint: URL { switch self { case .production: URL(string: "https://duckduckgo.com/api/protection/")! case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")! - case .dev: URL(string: "https://4842-20-93-28-24.ngrok-free.app/api/protection/")! } } From e8d4572c1dfe3d515e016866b4c2f0fc21b11917 Mon Sep 17 00:00:00 2001 From: Alexey Martemyanov Date: Thu, 28 Nov 2024 19:32:48 +0600 Subject: [PATCH 6/6] fix build --- Sources/MaliciousSiteProtection/API/APIClient.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/MaliciousSiteProtection/API/APIClient.swift b/Sources/MaliciousSiteProtection/API/APIClient.swift index 76e9dfc26..f4c08e446 100644 --- a/Sources/MaliciousSiteProtection/API/APIClient.swift +++ b/Sources/MaliciousSiteProtection/API/APIClient.swift @@ -85,7 +85,7 @@ public extension APIClient { defaultHeaders } - public func timeout(for request: APIClient.Request) -> URL { + public func timeout(for request: APIClient.Request) -> TimeInterval { switch request { case .hashPrefixSet, .filterSet: 60 // This could block navigation so we should favour navigation loading if the backend is degraded.