Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Malware protection 2: refactor APIClient #1092

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ let package = Package(
.testTarget(
name: "MaliciousSiteProtectionTests",
dependencies: [
"TestUtils",
"MaliciousSiteProtection",
],
resources: [
Expand Down
28 changes: 16 additions & 12 deletions Sources/Common/Extensions/URLExtension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,22 +354,24 @@ extension URL {

// MARK: - Parameters

@_disfavoredOverload // prefer ordered KeyValuePairs collection when `parameters` passed as a Dictionary literal to preserve order.
public func appendingParameters<QueryParams: Collection>(_ parameters: QueryParams, allowedReservedCharacters: CharacterSet? = nil) -> URL
where QueryParams.Element == (key: String, value: String) {
let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
})
return result
}

return parameters.reduce(self) { partialResult, parameter in
partialResult.appendingParameter(
name: parameter.key,
value: parameter.value,
allowedReservedCharacters: allowedReservedCharacters
)
}
public func appendingParameters(_ parameters: KeyValuePairs<String, String>, allowedReservedCharacters: CharacterSet? = nil) -> URL {
let result = self.appending(percentEncodedQueryItems: parameters.map { name, value in
URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
})
return result
}

public func appendingParameter(name: String, value: String, allowedReservedCharacters: CharacterSet? = nil) -> URL {
let queryItem = URLQueryItem(percentEncodingName: name,
value: value,
withAllowedCharacters: allowedReservedCharacters)
let queryItem = URLQueryItem(percentEncodingName: name, value: value, withAllowedCharacters: allowedReservedCharacters)
return self.appending(percentEncodedQueryItem: queryItem)
}

Expand All @@ -378,13 +380,15 @@ extension URL {
}

public func appending(percentEncodedQueryItems: [URLQueryItem]) -> URL {
guard var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self }
guard !percentEncodedQueryItems.isEmpty,
var components = URLComponents(url: self, resolvingAgainstBaseURL: true) else { return self }

var existingPercentEncodedQueryItems = components.percentEncodedQueryItems ?? [URLQueryItem]()
existingPercentEncodedQueryItems.append(contentsOf: percentEncodedQueryItems)
components.percentEncodedQueryItems = existingPercentEncodedQueryItems
let result = components.url ?? self

return components.url ?? self
return result
}

public func getQueryItems() -> [URLQueryItem]? {
Expand Down
172 changes: 81 additions & 91 deletions Sources/MaliciousSiteProtection/API/APIClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,129 +18,119 @@

import Common
import Foundation
import os
import Networking

public protocol APIClientProtocol {
func getFilterSet(revision: Int) async -> APIClient.FiltersChangeSetResponse
func getHashPrefixes(revision: Int) async -> APIClient.HashPrefixesChangeSetResponse
func getMatches(hashPrefix: String) async -> [Match]
func load<Request: APIRequestProtocol>(_ requestConfig: Request) async throws -> Request.ResponseType
}

public protocol URLSessionProtocol {
func data(for request: URLRequest) async throws -> (Data, URLResponse)
public extension APIClientProtocol where Self == APIClient {
static var production: APIClientProtocol { APIClient(environment: .production) }
static var staging: APIClientProtocol { APIClient(environment: .staging) }
}

extension URLSession: URLSessionProtocol {}

extension URLSessionProtocol {
public static var defaultSession: URLSessionProtocol {
return URLSession.shared
}
public protocol APIClientEnvironment {
func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2
func url(for request: APIClient.Request) -> URL
}

public struct APIClient: APIClientProtocol {
public extension APIClient {
enum DefaultEnvironment: APIClientEnvironment {

public enum Environment {
case production
case staging
}
case dev

enum Constants {
static let productionEndpoint = URL(string: "https://duckduckgo.com/api/protection/")!
static let stagingEndpoint = URL(string: "https://staging.duckduckgo.com/api/protection/")!
enum APIPath: String {
case filterSet
case hashPrefix
case matches
var endpoint: URL {
switch self {
case .production: URL(string: "https://duckduckgo.com/api/protection/")!
case .staging: URL(string: "https://staging.duckduckgo.com/api/protection/")!
case .dev: URL(string: "https://4842-20-93-28-24.ngrok-free.app/api/protection/")!
}
}
}

private let endpointURL: URL
private let session: URLSessionProtocol!
private var headers: [String: String]? = [:]
var defaultHeaders: APIRequestV2.HeadersV2 {
.init(userAgent: APIRequest.Headers.userAgent)
}

var filterSetURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.filterSet.rawValue)
}
enum APIPath {
static let filterSet = "filterSet"
static let hashPrefix = "hashPrefix"
static let matches = "matches"
}

var hashPrefixURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.hashPrefix.rawValue)
}
enum QueryParameter {
static let category = "category"
static let revision = "revision"
static let hashPrefix = "hashPrefix"
}

var matchesURL: URL {
endpointURL.appendingPathComponent(Constants.APIPath.matches.rawValue)
}
public func url(for request: APIClient.Request) -> URL {
switch request {
case .hashPrefixSet(let configuration):
endpoint.appendingPathComponent(APIPath.hashPrefix).appendingParameters([
QueryParameter.category: configuration.threatKind.rawValue,
QueryParameter.revision: (configuration.revision ?? 0).description,
])
case .filterSet(let configuration):
endpoint.appendingPathComponent(APIPath.filterSet).appendingParameters([
QueryParameter.category: configuration.threatKind.rawValue,
QueryParameter.revision: (configuration.revision ?? 0).description,
])
case .matches(let configuration):
endpoint.appendingPathComponent(APIPath.matches).appendingParameter(name: QueryParameter.hashPrefix, value: configuration.hashPrefix)
}
}

public init(environment: Environment = .production, session: URLSessionProtocol = URLSession.defaultSession) {
switch environment {
case .production:
endpointURL = Constants.productionEndpoint
case .staging:
endpointURL = Constants.stagingEndpoint
public func headers(for request: APIClient.Request) -> APIRequestV2.HeadersV2 {
defaultHeaders
}
self.session = session
}

public func getFilterSet(revision: Int) async -> FiltersChangeSetResponse {
guard let url = createURL(for: .filterSet, revision: revision) else {
logDebug("🔸 Invalid filterSet revision URL: \(revision)")
return FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}
return await fetch(url: url, responseType: FiltersChangeSetResponse.self) ?? FiltersChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}

public struct APIClient: APIClientProtocol {

let environment: APIClientEnvironment
private let service: APIService

public init(environment: Self.DefaultEnvironment = .production, service: APIService = DefaultAPIService(urlSession: .shared)) {
self.init(environment: environment as APIClientEnvironment, service: service)
}

public func getHashPrefixes(revision: Int) async -> HashPrefixesChangeSetResponse {
guard let url = createURL(for: .hashPrefix, revision: revision) else {
logDebug("🔸 Invalid hashPrefix revision URL: \(revision)")
return HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
}
return await fetch(url: url, responseType: HashPrefixesChangeSetResponse.self) ?? HashPrefixesChangeSetResponse(insert: [], delete: [], revision: revision, replace: false)
public init(environment: APIClientEnvironment, service: APIService) {
self.environment = environment
self.service = service
}

public func getMatches(hashPrefix: String) async -> [Match] {
let queryItems = [URLQueryItem(name: "hashPrefix", value: hashPrefix)]
guard let url = createURL(for: .matches, queryItems: queryItems) else {
logDebug("🔸 Invalid matches URL: \(hashPrefix)")
return []
}
return await fetch(url: url, responseType: MatchResponse.self)?.matches ?? []
public func load<Request: APIRequestProtocol>(_ requestConfig: Request) async throws -> Request.ResponseType {
let requestType = requestConfig.requestType
let headers = environment.headers(for: requestType)
let url = environment.url(for: requestType)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any specific client caching configs we should consider here? The server should return a TTL of 10 minutes for the API so I imagine this will be automatically respected here, but just checking.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I‘d say no as the caching policies should be set on the backend side: if there‘s caching it should be fine as all of our requests are parametrized (i.e. revision, hash etc)


let apiRequest = APIRequestV2(url: url, method: .get, headers: headers)
let response = try await service.fetch(request: apiRequest)
let result: Request.ResponseType = try response.decodeBody()

return result
}
}

// MARK: Private Methods
extension APIClient {
}

private func logDebug(_ message: String) {
Logger.api.debug("\(message)")
// MARK: - Convenience
extension APIClientProtocol {
public func filtersChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.FiltersChangeSet {
let result = try await load(.filterSet(threatKind: threatKind, revision: revision))
return result
}

private func createURL(for path: Constants.APIPath, revision: Int? = nil, queryItems: [URLQueryItem]? = nil) -> URL? {
// Start with the base URL and append the path component
var urlComponents = URLComponents(url: endpointURL.appendingPathComponent(path.rawValue), resolvingAgainstBaseURL: true)
var items = queryItems ?? []
if let revision = revision, revision > 0 {
items.append(URLQueryItem(name: "revision", value: String(revision)))
}
urlComponents?.queryItems = items.isEmpty ? nil : items
return urlComponents?.url
public func hashPrefixesChangeSet(for threatKind: ThreatKind, revision: Int) async throws -> APIClient.Response.HashPrefixesChangeSet {
let result = try await load(.hashPrefixes(threatKind: threatKind, revision: revision))
return result
}

private func fetch<T: Decodable>(url: URL, responseType: T.Type) async -> T? {
var request = URLRequest(url: url)
request.httpMethod = "GET"
request.allHTTPHeaderFields = headers

do {
let (data, _) = try await session.data(for: request)
if let response = try? JSONDecoder().decode(responseType, from: data) {
return response
} else {
logDebug("🔸 Failed to decode response for \(String(describing: responseType)): \(data)")
}
} catch {
logDebug("🔴 Failed to load \(String(describing: responseType)) data: \(error)")
}
return nil
public func matches(forHashPrefix hashPrefix: String) async throws -> APIClient.Response.Matches {
let result = try await load(.matches(hashPrefix: hashPrefix))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we pass a shorter timeout to the /matches endpoint? This could block navigation so we should favour navigation loading if the backend is degraded. On Android we're looking at a maximum 1 second timeout for this request. The other requests can be heavier but they are background tasks anyway so not as risky.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added parametrized Timeout

return result
}
}
84 changes: 84 additions & 0 deletions Sources/MaliciousSiteProtection/API/APIRequest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//
// APIRequest.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

public protocol APIRequestProtocol {
associatedtype ResponseType: Decodable
var requestType: APIClient.Request { get }
}

public extension APIClient {
enum Request {
case hashPrefixSet(HashPrefixes)
case filterSet(FilterSet)
case matches(Matches)
}
}
public extension APIClient.Request {
struct HashPrefixes: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.HashPrefixesChangeSet

public let threatKind: ThreatKind
public let revision: Int?

public var requestType: APIClient.Request {
.hashPrefixSet(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.HashPrefixes {
static func hashPrefixes(threatKind: ThreatKind, revision: Int?) -> Self {
.init(threatKind: threatKind, revision: revision)
}
}

public extension APIClient.Request {
struct FilterSet: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.FiltersChangeSet

public let threatKind: ThreatKind
public let revision: Int?

public var requestType: APIClient.Request {
.filterSet(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.FilterSet {
static func filterSet(threatKind: ThreatKind, revision: Int?) -> Self {
.init(threatKind: threatKind, revision: revision)
}
}

public extension APIClient.Request {
struct Matches: APIRequestProtocol {
public typealias ResponseType = APIClient.Response.Matches

public let hashPrefix: String

public var requestType: APIClient.Request {
.matches(self)
}
}
}
extension APIRequestProtocol where Self == APIClient.Request.Matches {
static func matches(hashPrefix: String) -> Self {
.init(hashPrefix: hashPrefix)
}
}
7 changes: 5 additions & 2 deletions Sources/MaliciousSiteProtection/API/ChangeSetResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ extension APIClient {
}
}

public typealias FiltersChangeSetResponse = ChangeSetResponse<Filter>
public typealias HashPrefixesChangeSetResponse = ChangeSetResponse<String>
public enum Response {
public typealias FiltersChangeSet = ChangeSetResponse<Filter>
public typealias HashPrefixesChangeSet = ChangeSetResponse<String>
public typealias Matches = MatchResponse
}

}
4 changes: 4 additions & 0 deletions Sources/MaliciousSiteProtection/API/MatchResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ extension APIClient {

public struct MatchResponse: Codable, Equatable {
public var matches: [Match]

public init(matches: [Match]) {
self.matches = matches
}
}

}
Loading
Loading