Skip to content

Commit

Permalink
refactor APIClient
Browse files Browse the repository at this point in the history
  • Loading branch information
mallexxx committed Nov 25, 2024
1 parent 993a45f commit 50ed41c
Show file tree
Hide file tree
Showing 18 changed files with 387 additions and 315 deletions.
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ let package = Package(
.testTarget(
name: "MaliciousSiteProtectionTests",
dependencies: [
"TestUtils",
"MaliciousSiteProtection",
],
resources: [
Expand Down
25 changes: 14 additions & 11 deletions Sources/Common/Extensions/URLExtension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,22 +354,24 @@ extension URL {

// MARK: - Parameters

@_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order.
public func appendingParameters<QueryParams: Collection>(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL
where QueryParams.Element == (key: String, value: String) {
let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
})
return result
}

return parameters.reduce(self) { partialResult, parameter in
partialResult.appendingParameter(
name: parameter.key,
value: parameter.value,
allowedReservedCharacters: allowedReservedCharacters
)
}
public func appendingParameters(_ parameters: KeyValuePairs<String, String>, allowedReservedCharacters: CharacterSet? = nil) -> URL {
let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
})
return result
}

public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL {
let queryItem = URLQueryItem(percentEncodingName: name,
value: value,
withAllowedCharacters: allowedReservedCharacters)
let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
return self.appending(percentEncodedQueryItem: queryItem)
}

Expand All @@ -383,8 +385,9 @@ extension URL {
var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]()
existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems)
components.percentEncodedQueryItems = existingPercentEncodedQueryItems
let result = components.url ?? self

return components.url ?? self
return result
}

public func getQueryItems() -> [URLQueryItem]? {
Expand Down
172 changes: 81 additions & 91 deletions Sources/MaliciousSiteProtection/API/APIClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,129 +18,119 @@

import Common
import Foundation
import os
import Networking

public protocol APIClientProtocol {
func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse
func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse
func getMatches(hashPrefix: String) async -> [Match]
func load<Request: APIRequestProtocol>(_ requestConfig: Request) async throws -> Request.ResponseType
}

public protocol URLSessionProtocol {
func data(for request: URLRequest) async throws -> (Data, URLResponse)
public extension APIClientProtocol where Self == APIClient {
static var production: APIClientProtocol { APIClient(environment: .production) }
static var staging: APIClientProtocol { APIClient(environment: .staging) }
}

extension URLSession: URLSessionProtocol {}

extension URLSessionProtocol {
public static var defaultSession: URLSessionProtocol {
return URLSession.shared
}
public protocol APIClientEnvironment {
func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2
func url(for request: APIClient.Request) -> URL
}

public struct APIClient: APIClientProtocol {
public extension APIClient {
enum DefaultEnvironment: APIClientEnvironment {

public enum Environment {
case production
case staging
}
case dev

enum Constants {
static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")!
static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")!
enum APIPath: String {
case filterSet
case hashPrefix
case matches
var endpoint: URL {
switch self {
case .production: URL(string: "https://duckduckgo.com/api/protection/")!
case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")!
case .dev: URL(string: "https://4842-20-93-28-24.ngrok-free.app/api/protection/")!
}
}
}

private let endpointURL: URL
private let session: URLSessionProtocol!
private var headers: [String: String]? = [:]
var defaultHeaders: APIRequestV2.HeadersV2 {
.init(userAgent: APIRequest.Headers.userAgent)
}

var filterSetURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue)
}
enum APIPath {
static let filterSet = "filterSet"
static let hashPrefix = "hashPrefix"
static let matches = "matches"
}

var hashPrefixURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue)
}
enum QueryParameter {
static let category = "category"
static let revision = "revision"
static let hashPrefix = "hashPrefix"
}

var matchesURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue)
}
public func url(for request: APIClient.Request) -> URL {
switch request {
case .hashPrefixSet(let configuration):
endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([
QueryParameter.category: configuration.threatKind.rawValue,
QueryParameter.revision: (configuration.revision ?? 0).description,
])
case .filterSet(let configuration):
endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([
QueryParameter.category: configuration.threatKind.rawValue,
QueryParameter.revision: (configuration.revision ?? 0).description,
])
case .matches(let configuration):
endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix)
}
}

public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) {
switch environment {
case .production:
endpointURL = Constants.productionEndpoint
case .staging:
endpointURL = Constants.stagingEndpoint
public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 {
defaultHeaders
}
self.session = session
}

public func getFilterSet(revision: Int) async -> FiltersChangeSetResponse {
guard let url = createURL(for: .filterSet, revision: revision) else {
logDebug("🔸 Invalid filterSet revision URL: \(revision)")
return FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}
return await fetch(url: url, responseType: FiltersChangeSetResponse.self) ?? FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}

public struct APIClient: APIClientProtocol {

let environment: APIClientEnvironment
private let service: APIService

public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) {
self.init(environment: environment as APIClientEnvironment, service: service)
}

public func getHashPrefixes(revision: Int) async -> HashPrefixesChangeSetResponse {
guard let url = createURL(for: .hashPrefix, revision: revision) else {
logDebug("🔸 Invalid hashPrefix revision URL: \(revision)")
return HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}
return await fetch(url: url, responseType: HashPrefixesChangeSetResponse.self) ?? HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
public init(environment: APIClientEnvironment, service: APIService) {
self.environment = environment
self.service = service
}

public func getMatches(hashPrefix: String) async -> [Match] {
let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)]
guard let url = createURL(for: .matches, queryItems: queryItems) else {
logDebug("🔸 Invalid matches URL: \(hashPrefix)")
return []
}
return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? []
public func load<Request: APIRequestProtocol>(_ requestConfig: Request) async throws -> Request.ResponseType {
let requestType = requestConfig.requestType
let headers = environment.headers(for: requestType)
let url = environment.url(for: requestType)

let apiRequest = APIRequestV2(url: url, method: .get, headers: headers)
let response = try await service.fetch(request: apiRequest)
let result: Request.ResponseType = try response.decodeBody()

return result
}
}

// MARK: Private Methods
extension APIClient {
}

private func logDebug(_ message: String) {
Logger.api.debug("\(message)")
// MARK: - Convenience
extension APIClientProtocol {
public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet {
let result = try await load(.filterSet(threatKind: threatKind, revision: revision))
return result
}

private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? {
// Start with the base URL and append the path component
var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true)
var items = queryItems ?? []
if let revision = revision, revision > 0 {
items.append(URLQueryItem(name: "revision", value: String(revision)))
}
urlComponents?.queryItems = items.isEmpty ? nil : items
return urlComponents?.url
public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet {
let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision))
return result
}

private func fetch<T: Decodable>(url: URL, responseType: T.Type) async -> T? {
var request = URLRequest(url: url)
request.httpMethod = "GET"
request.allHTTPHeaderFields = headers

do {
let (data, _) = try await session.data(for: request)
if let response = try? JSONDecoder().decode(responseType, from: data) {
return response
} else {
logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)")
}
} catch {
logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)")
}
return nil
public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches {
let result = try await load(.matches(hashPrefix: hashPrefix))
return result
}
}
84 changes: 84 additions & 0 deletions Sources/MaliciousSiteProtection/API/APIRequest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//
// APIRequest.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

public protocol APIRequestProtocol {
associatedtype ResponseType: Decodable
var requestType: APIClient.Request { get }
}

public extension APIClient {
enum Request {
case hashPrefixSet(HashPrefixes)
case filterSet(FilterSet)
case matches(Matches)
}
}
public extension APIClient.Request {
struct HashPrefixes: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet

public let threatKind: ThreatKind
public let revision: Int?

public var requestType: APIClient.Request {
.hashPrefixSet(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes {
static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self {
.init(threatKind: threatKind, revision: revision)
}
}

public extension APIClient.Request {
struct FilterSet: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.FiltersChangeSet

public let threatKind: ThreatKind
public let revision: Int?

public var requestType: APIClient.Request {
.filterSet(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.FilterSet {
static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self {
.init(threatKind: threatKind, revision: revision)
}
}

public extension APIClient.Request {
struct Matches: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.Matches

public let hashPrefix: String

public var requestType: APIClient.Request {
.matches(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.Matches {
static func matches(hashPrefix: String) -> Self {
.init(hashPrefix: hashPrefix)
}
}
7 changes: 5 additions & 2 deletions Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ extension APIClient {
}
}

public typealias FiltersChangeSetResponse = ChangeSetResponse<Filter>
public typealias HashPrefixesChangeSetResponse = ChangeSetResponse<String>
public enum Response {
public typealias FiltersChangeSet = ChangeSetResponse<Filter>
public typealias HashPrefixesChangeSet = ChangeSetResponse<String>
public typealias Matches = MatchResponse
}

}
4 changes: 4 additions & 0 deletions Sources/MaliciousSiteProtection/API/MatchResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ extension APIClient {

public struct MatchResponse: Codable, Equatable {
public var matches: [Match]

public init(matches: [Match]) {
self.matches = matches
}
}

}
12 changes: 7 additions & 5 deletions Sources/MaliciousSiteProtection/MaliciousSiteDetector.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting {
self.eventMapping = eventMapping
}

private func getMatches(hashPrefix: String) async -> Set<Match> {
return Set(await apiClient.getMatches(hashPrefix: hashPrefix))
}

private func inFilterSet(hash: String) -> Set<Filter> {
return Set(dataManager.filterSet.filter { $0.hash == hash })
}
Expand All @@ -65,7 +61,13 @@ public final class MaliciousSiteDetector: MaliciousSiteDetecting {
}

private func fetchMatches(hashPrefix: String) async -> [Match] {
return await apiClient.getMatches(hashPrefix: hashPrefix)
do {
let response = try await apiClient.matches(forHashPrefix: hashPrefix)
return response.matches
} catch {
Logger.api.error("Failed to fetch matches for hash prefix: \(hashPrefix): \(error.localizedDescription)")
return []
}
}

private func checkLocalFilters(canonicalHost: String, canonicalUrl: URL) -> Bool {
Expand Down
Loading

0 comments on commit 50ed41c

Please sign in to comment.