diff --git a/Sources/Subscription/AccountManager.swift b/Sources/Subscription/AccountManager.swift index 53b2b1832..5cd74cfa3 100644 --- a/Sources/Subscription/AccountManager.swift +++ b/Sources/Subscription/AccountManager.swift @@ -178,6 +178,7 @@ public class AccountManager: AccountManaging { do { try storage.clearAuthenticationState() try accessTokenStorage.removeAccessToken() + SubscriptionService.signOut() entitlementsCache.reset() } catch { if let error = error as? AccountKeychainAccessError { @@ -219,8 +220,8 @@ public class AccountManager: AccountManaging { case noCachedData } - public func hasEntitlement(for entitlement: Entitlement.ProductName) async -> Result { - switch await fetchEntitlements() { + public func hasEntitlement(for entitlement: Entitlement.ProductName, cachePolicy: CachePolicy = .returnCacheDataElseLoad) async -> Result { + switch await fetchEntitlements(cachePolicy: cachePolicy) { case .success(let entitlements): return .success(entitlements.compactMap { $0.product }.contains(entitlement)) case .failure(let error): @@ -251,9 +252,9 @@ public class AccountManager: AccountManaging { } } - public func fetchEntitlements(policy: CachePolicy = .returnCacheDataElseLoad) async -> Result<[Entitlement], Error> { + public func fetchEntitlements(cachePolicy: CachePolicy = .returnCacheDataElseLoad) async -> Result<[Entitlement], Error> { - switch policy { + switch cachePolicy { case .reloadIgnoringLocalCacheData: return await fetchRemoteEntitlements() diff --git a/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift b/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift index b0ce6a943..2b3bb65fa 100644 --- a/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift +++ b/Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift @@ -65,6 +65,9 @@ public final class AppStorePurchaseFlow { let accountManager = AccountManager(subscriptionAppGroup: subscriptionAppGroup) let externalID: String + // Clear the Subscription cache + SubscriptionService.signOut() + // Check for past transactions most recent switch await AppStoreRestoreFlow.restoreAccountFromPastPurchase(subscriptionAppGroup: subscriptionAppGroup) { case .success: diff --git a/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift b/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift index 9bff11b15..eec52b7b6 100644 --- a/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift +++ b/Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift @@ -78,7 +78,7 @@ public final class AppStoreRestoreFlow { var isSubscriptionActive = false - switch await SubscriptionService.getSubscription(accessToken: accessToken) { + switch await SubscriptionService.getSubscription(accessToken: accessToken, cachePolicy: .reloadIgnoringLocalCacheData) { case .success(let subscription): isSubscriptionActive = subscription.isActive case .failure: diff --git a/Sources/Subscription/Services/Model/Entitlement.swift b/Sources/Subscription/Services/Model/Entitlement.swift index 284f949a9..ece4b618a 100644 --- a/Sources/Subscription/Services/Model/Entitlement.swift +++ b/Sources/Subscription/Services/Model/Entitlement.swift @@ -19,7 +19,6 @@ import Foundation public struct Entitlement: Codable, Equatable { - let id: Int let name: String public let product: ProductName diff --git a/Sources/Subscription/Services/SubscriptionService.swift b/Sources/Subscription/Services/SubscriptionService.swift index 1248bde1c..030d28251 100644 --- a/Sources/Subscription/Services/SubscriptionService.swift +++ b/Sources/Subscription/Services/SubscriptionService.swift @@ -20,7 +20,7 @@ import Common import Foundation import Macros -public struct SubscriptionService: APIService { +public final class SubscriptionService: APIService { public static let session = { let configuration = URLSessionConfiguration.ephemeral @@ -36,24 +36,60 @@ public struct SubscriptionService: APIService { } } - // MARK: - + private static let subscriptionCache = UserDefaultsCache(key: UserDefaultsCacheKey.subscription) + + public enum CachePolicy { + case reloadIgnoringLocalCacheData + case returnCacheDataElseLoad + case returnCacheDataDontLoad + } + + public enum SubscriptionServiceError: Error { + case noCachedData + case apiError(APIServiceError) + } + + // MARK: - Subscription fetching with caching - public static func getSubscription(accessToken: String) async -> Result { + private static func getRemoteSubscription(accessToken: String) async -> Result { let result: Result = await executeAPICall(method: "GET", endpoint: "subscription", headers: makeAuthorizationHeader(for: accessToken)) switch result { - case .success(let response): - cachedGetSubscriptionResponse = response - case .failure: - cachedGetSubscriptionResponse = nil + case .success(let subscriptionResponse): + subscriptionCache.set(subscriptionResponse) + return .success(subscriptionResponse) + case .failure(let error): + return .failure(.apiError(error)) } + } - return result + public static func getSubscription(accessToken: String, cachePolicy: CachePolicy = .returnCacheDataElseLoad) async -> Result { + + switch cachePolicy { + case .reloadIgnoringLocalCacheData: + return await getRemoteSubscription(accessToken: accessToken) + + case .returnCacheDataElseLoad: + if let cachedSubscription = subscriptionCache.get() { + return .success(cachedSubscription) + } else { + return await getRemoteSubscription(accessToken: accessToken) + } + + case .returnCacheDataDontLoad: + if let cachedSubscription = subscriptionCache.get() { + return .success(cachedSubscription) + } else { + return .failure(.noCachedData) + } + } } - public typealias GetSubscriptionResponse = Subscription + public static func signOut() { + subscriptionCache.reset() + } - public static var cachedGetSubscriptionResponse: GetSubscriptionResponse? + public typealias GetSubscriptionResponse = Subscription // MARK: - diff --git a/Sources/Subscription/UserDefaultsCache.swift b/Sources/Subscription/UserDefaultsCache.swift index 86d5fdf2b..a2daf720c 100644 --- a/Sources/Subscription/UserDefaultsCache.swift +++ b/Sources/Subscription/UserDefaultsCache.swift @@ -18,28 +18,56 @@ import Foundation +public struct UserDefaultsCacheSettings { + + // Default expiration interval set to 24 hours + public let defaultExpirationInterval: TimeInterval + + public init(defaultExpirationInterval: TimeInterval = 24 * 60 * 60) { + self.defaultExpirationInterval = defaultExpirationInterval + } +} + public enum UserDefaultsCacheKey: String { - case subscriptionEntitlements + case subscriptionEntitlements = "com.duckduckgo.bsk.subscription.entitlements" + case subscription = "com.duckduckgo.bsk.subscription.info" } -/// A generic UserDefaults cache for storing and retrieving Codable objects. +/// A generic UserDefaults cache for storing and retrieving Codable objects public class UserDefaultsCache { - private var subscriptionAppGroup: String - private lazy var userDefaults: UserDefaults? = UserDefaults(suiteName: subscriptionAppGroup) + + private struct CacheObject: Codable { + let expires: Date + let object: ObjectType + } + + private var subscriptionAppGroup: String? + private var settings: UserDefaultsCacheSettings + private lazy var userDefaults: UserDefaults? = { + if let appGroup = subscriptionAppGroup { + return UserDefaults(suiteName: appGroup) + } else { + return UserDefaults.standard + } + }() + private let key: UserDefaultsCacheKey - public init(subscriptionAppGroup: String, key: UserDefaultsCacheKey) { + public init(subscriptionAppGroup: String? = nil, key: UserDefaultsCacheKey, + settings: UserDefaultsCacheSettings = UserDefaultsCacheSettings()) { self.subscriptionAppGroup = subscriptionAppGroup self.key = key + self.settings = settings } - public func set(_ object: ObjectType) { + public func set(_ object: ObjectType, expires: Date = Date().addingTimeInterval(UserDefaultsCacheSettings().defaultExpirationInterval)) { + let cacheObject = CacheObject(expires: expires, object: object) let encoder = JSONEncoder() do { - let data = try encoder.encode(object) + let data = try encoder.encode(cacheObject) userDefaults?.set(data, forKey: key.rawValue) } catch { - assertionFailure("Failed to encode object of type \(ObjectType.self): \(error)") + assertionFailure("Failed to encode CacheObject: \(error)") } } @@ -47,10 +75,14 @@ public class UserDefaultsCache { guard let data = userDefaults?.data(forKey: key.rawValue) else { return nil } let decoder = JSONDecoder() do { - let object = try decoder.decode(ObjectType.self, from: data) - return object + let cacheObject = try decoder.decode(CacheObject.self, from: data) + if cacheObject.expires > Date() { + return cacheObject.object + } else { + reset() // Clear expired data + return nil + } } catch { - assertionFailure("Failed to decode object of type \(ObjectType.self): \(error)") return nil } } @@ -58,5 +90,4 @@ public class UserDefaultsCache { public func reset() { userDefaults?.removeObject(forKey: key.rawValue) } - }