diff --git a/Sources/BrowserServicesKit/LinkProtection/LinkProtection.swift b/Sources/BrowserServicesKit/LinkProtection/LinkProtection.swift index 907316b03..fe23aeb1a 100644 --- a/Sources/BrowserServicesKit/LinkProtection/LinkProtection.swift +++ b/Sources/BrowserServicesKit/LinkProtection/LinkProtection.swift @@ -36,10 +36,26 @@ public struct LinkProtection { errorReporting: errorReporting) } + private func makeNewRequest(changingUrl url: URL, inRequest request: URLRequest) -> URLRequest { + var newRequest = request + newRequest.url = url + return newRequest + } + public mutating func setMainFrameUrl(_ url: URL?) { mainFrameUrl = url } + public func getCleanURLRequest(from urlRequest: URLRequest, + onStartExtracting: () -> Void, + onFinishExtracting: @escaping () -> Void, + completion: @escaping (URLRequest) -> Void) { + getCleanURL(from: urlRequest.url!, onStartExtracting: onStartExtracting, onFinishExtracting: onFinishExtracting) { newUrl in + let newRequest = makeNewRequest(changingUrl: newUrl, inRequest: urlRequest) + completion(newRequest) + } + } + public func getCleanURL(from url: URL, onStartExtracting: () -> Void, onFinishExtracting: @escaping () -> Void, @@ -77,11 +93,12 @@ public struct LinkProtection { // swiftlint:disable function_parameter_count public func requestTrackingLinkRewrite(initiatingURL: URL?, - destinationURL: URL, + destinationRequest: URLRequest, onStartExtracting: () -> Void, onFinishExtracting: @escaping () -> Void, - onLinkRewrite: @escaping (URL) -> Void, + onLinkRewrite: @escaping (URLRequest) -> Void, policyDecisionHandler: @escaping (Bool) -> Void) -> Bool { + let destinationURL = destinationRequest.url if let mainFrameUrl = mainFrameUrl, destinationURL != mainFrameUrl { // If mainFrameUrl is set and is different from destinationURL we will assume this is a redirect // We do not rewrite redirects due to breakage concerns @@ -91,7 +108,7 @@ public struct LinkProtection { var didRewriteLink = false if let newURL = linkCleaner.extractCanonicalFromAMPLink(initiator: initiatingURL, destination: destinationURL) { policyDecisionHandler(false) - onLinkRewrite(newURL) + onLinkRewrite(makeNewRequest(changingUrl: newURL, inRequest: destinationRequest)) didRewriteLink = true } else if ampExtractor.urlContainsAMPKeyword(destinationURL) { onStartExtracting() @@ -103,13 +120,13 @@ public struct LinkProtection { } policyDecisionHandler(false) - onLinkRewrite(canonical) + onLinkRewrite(makeNewRequest(changingUrl: canonical, inRequest: destinationRequest)) } didRewriteLink = true } else if let newURL = linkCleaner.cleanTrackingParameters(initiator: initiatingURL, url: destinationURL) { if newURL != destinationURL { policyDecisionHandler(false) - onLinkRewrite(newURL) + onLinkRewrite(makeNewRequest(changingUrl: newURL, inRequest: destinationRequest)) didRewriteLink = true } } @@ -121,10 +138,10 @@ public struct LinkProtection { navigationAction: WKNavigationAction, onStartExtracting: () -> Void, onFinishExtracting: @escaping () -> Void, - onLinkRewrite: @escaping (URL, WKNavigationAction) -> Void, + onLinkRewrite: @escaping (URLRequest, WKNavigationAction) -> Void, policyDecisionHandler: @escaping (WKNavigationActionPolicy) -> Void) -> Bool { requestTrackingLinkRewrite(initiatingURL: initiatingURL, - destinationURL: navigationAction.request.url!, + destinationRequest: navigationAction.request, onStartExtracting: onStartExtracting, onFinishExtracting: onFinishExtracting, onLinkRewrite: { onLinkRewrite($0, navigationAction) }, @@ -134,13 +151,13 @@ public struct LinkProtection { @MainActor public func requestTrackingLinkRewrite(initiatingURL: URL?, - destinationURL: URL, + destinationRequest: URLRequest, onStartExtracting: () -> Void, onFinishExtracting: @escaping () -> Void, - onLinkRewrite: @escaping (URL) -> Void) async -> Bool? { + onLinkRewrite: @escaping (URLRequest) -> Void) async -> Bool? { await withCheckedContinuation { continuation in let didRewriteLink = requestTrackingLinkRewrite(initiatingURL: initiatingURL, - destinationURL: destinationURL, + destinationRequest: destinationRequest, onStartExtracting: onStartExtracting, onFinishExtracting: onFinishExtracting, onLinkRewrite: onLinkRewrite) { navigationActionPolicy in