Skip to content

Commit

Permalink
fix: session refresh loop in all request interceptors
Browse files Browse the repository at this point in the history
  • Loading branch information
anku255 committed Jun 5, 2024
1 parent 2854c17 commit 2fdb97b
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.4.0] - 2024-06-05

### Changes

- Fixed the session refresh loop in all the request interceptors that occurred when an API returned a 401 response despite a valid session. Interceptors now attempt to refresh the session a maximum of ten times before throwing an error. The retry limit is configurable via the `maxRetryAttemptsForSessionRefresh` option.

## [0.3.2] - 2024-05-28

- Readds FDI 2.0 and 3.0 support
Expand Down
2 changes: 1 addition & 1 deletion SuperTokensIOS.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

Pod::Spec.new do |s|
s.name = 'SuperTokensIOS'
s.version = "0.3.2"
s.version = "0.4.0"
s.summary = 'SuperTokens SDK for using login and session management functionality in iOS apps'

# This description is used to generate tags and improve search results.
Expand Down
1 change: 1 addition & 0 deletions SuperTokensIOS/Classes/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public enum SuperTokensError: Error {
case apiError(message: String)
case generalError(message: String)
case illegalAccess(message: String)
case maxRetryAttemptsReachedForSessionRefresh(message: String)
}

internal enum SDKFailableError: Error {
Expand Down
4 changes: 2 additions & 2 deletions SuperTokensIOS/Classes/SuperTokens.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ public class SuperTokens {
FrontToken.setItem(frontToken: "remove")
}

public static func initialize(apiDomain: String, apiBasePath: String? = nil, sessionExpiredStatusCode: Int? = nil, sessionTokenBackendDomain: String? = nil, tokenTransferMethod: SuperTokensTokenTransferMethod? = nil, userDefaultsSuiteName: String? = nil, eventHandler: ((EventType) -> Void)? = nil, preAPIHook: ((APIAction, URLRequest) -> URLRequest)? = nil, postAPIHook: ((APIAction, URLRequest, URLResponse?) -> Void)? = nil) throws {
public static func initialize(apiDomain: String, apiBasePath: String? = nil, sessionExpiredStatusCode: Int? = nil, sessionTokenBackendDomain: String? = nil, maxRetryAttemptsForSessionRefresh: Int? = nil, tokenTransferMethod: SuperTokensTokenTransferMethod? = nil, userDefaultsSuiteName: String? = nil, eventHandler: ((EventType) -> Void)? = nil, preAPIHook: ((APIAction, URLRequest) -> URLRequest)? = nil, postAPIHook: ((APIAction, URLRequest, URLResponse?) -> Void)? = nil) throws {
if SuperTokens.isInitCalled {
return;
}

SuperTokens.config = try NormalisedInputType.normaliseInputType(apiDomain: apiDomain, apiBasePath: apiBasePath, sessionExpiredStatusCode: sessionExpiredStatusCode, sessionTokenBackendDomain: sessionTokenBackendDomain, tokenTransferMethod: tokenTransferMethod, eventHandler: eventHandler, preAPIHook: preAPIHook, postAPIHook: postAPIHook, userDefaultsSuiteName: userDefaultsSuiteName)
SuperTokens.config = try NormalisedInputType.normaliseInputType(apiDomain: apiDomain, apiBasePath: apiBasePath, sessionExpiredStatusCode: sessionExpiredStatusCode, maxRetryAttemptsForSessionRefresh: maxRetryAttemptsForSessionRefresh, sessionTokenBackendDomain: sessionTokenBackendDomain, tokenTransferMethod: tokenTransferMethod, eventHandler: eventHandler, preAPIHook: preAPIHook, postAPIHook: postAPIHook, userDefaultsSuiteName: userDefaultsSuiteName)

guard let _config: NormalisedInputType = SuperTokens.config else {
throw SuperTokensError.initError(message: "Error initialising SuperTokens")
Expand Down
15 changes: 15 additions & 0 deletions SuperTokensIOS/Classes/SuperTokensURLProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Foundation

public class SuperTokensURLProtocol: URLProtocol {
private static let readWriteDispatchQueue = DispatchQueue(label: "io.supertokens.session.readwrite", attributes: .concurrent)
private var sessionRefreshAttempts = 0

// Refer to comment in makeRequest to know why this is needed
private var requestForRetry: NSMutableURLRequest? = nil
Expand Down Expand Up @@ -122,10 +123,24 @@ public class SuperTokensURLProtocol: URLProtocol {
)

if httpResponse.statusCode == SuperTokens.config!.sessionExpiredStatusCode {
/**
* An API may return a 401 error response even with a valid session, causing a session refresh loop in the interceptor.
* To prevent this infinite loop, we break out of the loop after retrying the original request a specified number of times.
* The maximum number of retry attempts is defined by maxRetryAttemptsForSessionRefresh config variable.
*/
if self.sessionRefreshAttempts >= SuperTokens.config!.maxRetryAttemptsForSessionRefresh {
let errorMessage = "Error: Received 401 response from \(String(describing: apiRequest.url)). After refreshing the session and retrying the request \(SuperTokens.config!.maxRetryAttemptsForSessionRefresh ) times, we still received 401 responses. Maximum session refresh limit reached. Breaking out of the refresh loop. Please investigate your API. Consider increasing maxRetryAttemptsForSessionRefresh in the config if needed."
print(errorMessage)
self.resolveToUser(data: nil, response: nil, error: SuperTokensError.maxRetryAttemptsReachedForSessionRefresh(message: errorMessage))
return
}

mutableRequest = self.removeAuthHeaderIfMatchesLocalToken(_mutableRequest: mutableRequest)
SuperTokensURLProtocol.onUnauthorisedResponse(preRequestLocalSessionState: preRequestLocalSessionState, callback: {
unauthResponse in

self.sessionRefreshAttempts += 1;

if unauthResponse.status == .RETRY {
self.requestForRetry = mutableRequest
self.makeRequest()
Expand Down
19 changes: 16 additions & 3 deletions SuperTokensIOS/Classes/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,25 @@ class NormalisedInputType {
var apiDomain: String
var apiBasePath: String
var sessionExpiredStatusCode: Int
/**
* This specifies the maximum number of times the interceptor will attempt to refresh
* the session when a 401 Unauthorized response is received. If the number of retries
* exceeds this limit, no further attempts will be made to refresh the session, and
* and an error will be thrown.
*/
var maxRetryAttemptsForSessionRefresh: Int
var sessionTokenBackendDomain: String?
var eventHandler: (EventType) -> Void
var preAPIHook: (APIAction, URLRequest) -> URLRequest
var postAPIHook: (APIAction, URLRequest, URLResponse?) -> Void
var userDefaultsSuiteName: String?
var tokenTransferMethod: SuperTokensTokenTransferMethod

init(apiDomain: String, apiBasePath: String, sessionExpiredStatusCode: Int, sessionTokenBackendDomain: String?, tokenTransferMethod: SuperTokensTokenTransferMethod, eventHandler: @escaping (EventType) -> Void, preAPIHook: @escaping (APIAction, URLRequest) -> URLRequest, postAPIHook: @escaping (APIAction, URLRequest, URLResponse?) -> Void, userDefaultsSuiteName: String?) {
init(apiDomain: String, apiBasePath: String, sessionExpiredStatusCode: Int, maxRetryAttemptsForSessionRefresh: Int, sessionTokenBackendDomain: String?, tokenTransferMethod: SuperTokensTokenTransferMethod, eventHandler: @escaping (EventType) -> Void, preAPIHook: @escaping (APIAction, URLRequest) -> URLRequest, postAPIHook: @escaping (APIAction, URLRequest, URLResponse?) -> Void, userDefaultsSuiteName: String?) {
self.apiDomain = apiDomain
self.apiBasePath = apiBasePath
self.sessionExpiredStatusCode = sessionExpiredStatusCode
self.maxRetryAttemptsForSessionRefresh = maxRetryAttemptsForSessionRefresh
self.sessionTokenBackendDomain = sessionTokenBackendDomain
self.eventHandler = eventHandler
self.preAPIHook = preAPIHook
Expand Down Expand Up @@ -98,7 +106,7 @@ class NormalisedInputType {
return noDotNormalised
}

internal static func normaliseInputType(apiDomain: String, apiBasePath: String?, sessionExpiredStatusCode: Int?, sessionTokenBackendDomain: String?, tokenTransferMethod: SuperTokensTokenTransferMethod?, eventHandler: ((EventType) -> Void)?, preAPIHook: ((APIAction, URLRequest) -> URLRequest)?, postAPIHook: ((APIAction, URLRequest, URLResponse?) -> Void)?, userDefaultsSuiteName: String?) throws -> NormalisedInputType {
internal static func normaliseInputType(apiDomain: String, apiBasePath: String?, sessionExpiredStatusCode: Int?, maxRetryAttemptsForSessionRefresh: Int? = nil, sessionTokenBackendDomain: String?, tokenTransferMethod: SuperTokensTokenTransferMethod?, eventHandler: ((EventType) -> Void)?, preAPIHook: ((APIAction, URLRequest) -> URLRequest)?, postAPIHook: ((APIAction, URLRequest, URLResponse?) -> Void)?, userDefaultsSuiteName: String?) throws -> NormalisedInputType {
let _apiDomain = try NormalisedURLDomain(url: apiDomain)
var _apiBasePath = try NormalisedURLPath(input: "/auth")

Expand All @@ -111,6 +119,11 @@ class NormalisedInputType {
_sessionExpiredStatusCode = sessionExpiredStatusCode!
}

var _maxRetryAttemptsForSessionRefresh: Int = 10
if maxRetryAttemptsForSessionRefresh != nil {
_maxRetryAttemptsForSessionRefresh = maxRetryAttemptsForSessionRefresh!
}

var _sessionTokenBackendDomain: String? = nil
if sessionTokenBackendDomain != nil {
_sessionTokenBackendDomain = try normaliseSessionScopeOrThrowError(sessionScope: sessionTokenBackendDomain!)
Expand Down Expand Up @@ -144,7 +157,7 @@ class NormalisedInputType {
}


return NormalisedInputType(apiDomain: _apiDomain.getAsStringDangerous(), apiBasePath: _apiBasePath.getAsStringDangerous(), sessionExpiredStatusCode: _sessionExpiredStatusCode, sessionTokenBackendDomain: _sessionTokenBackendDomain, tokenTransferMethod: _tokenTransferMethod, eventHandler: _eventHandler, preAPIHook: _preAPIHook, postAPIHook: _postApiHook, userDefaultsSuiteName: userDefaultsSuiteName)
return NormalisedInputType(apiDomain: _apiDomain.getAsStringDangerous(), apiBasePath: _apiBasePath.getAsStringDangerous(), sessionExpiredStatusCode: _sessionExpiredStatusCode, maxRetryAttemptsForSessionRefresh: _maxRetryAttemptsForSessionRefresh, sessionTokenBackendDomain: _sessionTokenBackendDomain, tokenTransferMethod: _tokenTransferMethod, eventHandler: _eventHandler, preAPIHook: _preAPIHook, postAPIHook: _postApiHook, userDefaultsSuiteName: userDefaultsSuiteName)
}
}

Expand Down
2 changes: 1 addition & 1 deletion SuperTokensIOS/Classes/Version.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ import Foundation

internal class Version {
static let supported_fdi: [String] = ["1.16", "1.17", "1.18", "1.19", "2.0", "3.0"]
static let sdkVersion = "0.3.2"
static let sdkVersion = "0.4.0"
}
5 changes: 5 additions & 0 deletions testHelpers/server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,10 @@ app.get("/testError", (req, res) => {
res.status(500).send("test error message");
});

app.get("/throw-401", (req, res) => {
res.status(401).send("Unauthorised");
});

app.get("/stop", async (req, res) => {
process.exit();
});
Expand Down Expand Up @@ -579,6 +583,7 @@ app.use("*", async (req, res, next) => {
app.use(errorHandler());

app.use(async (err, req, res, next) => {
console.log("err", err);
res.send(500).send(err);
});

Expand Down
187 changes: 187 additions & 0 deletions testHelpers/testapp/Tests/sessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1260,4 +1260,191 @@ class sessionTests: XCTestCase {
//
// XCTAssertTrue(failureMessage == nil, failureMessage ?? "")
// }

func testBreakOutOfSessionRefreshLoopAfterDefaultMaxRetryAttempts() {
TestUtils.startST()

var failureMessage: String? = nil;
do {
try SuperTokens.initialize(apiDomain: testAPIBase, tokenTransferMethod: .cookie)
} catch {
failureMessage = "supertokens init failed"
}

let requestSemaphore = DispatchSemaphore(value: 0)

// Step 1: Login request
URLSession.shared.dataTask(with: TestUtils.getLoginRequest(), completionHandler: { data, response, error in
if error != nil {
failureMessage = "login API error"
requestSemaphore.signal()
return
}

if let httpResponse = response as? HTTPURLResponse {
if httpResponse.statusCode != 200 {
failureMessage = "Login response code is not 200";
requestSemaphore.signal()
} else {
let throw401URL = URL(string: "\(testAPIBase)/throw-401")!
var throw401Request = URLRequest(url: throw401URL)
throw401Request.httpMethod = "GET"

URLSession.shared.dataTask(with: throw401Request, completionHandler: { data, response, error in
if let error = error {

if (error as NSError).code != 4 {
failureMessage = "Expected the error code to be 4 (maxRetryAttemptsReachedForSessionRefresh)"
requestSemaphore.signal()
return;
}


let count = TestUtils.getRefreshTokenCounter()
if count != 10 {
failureMessage = "Expected refresh to be called 10 times but it was called " + String(count) + " times"
}
requestSemaphore.signal()
} else {
failureMessage = "Expected /throw-401 request to throw error"
requestSemaphore.signal()
}
}).resume()
}
} else {
failureMessage = "Login response is nil"
requestSemaphore.signal()
}
}).resume()

_ = requestSemaphore.wait(timeout: DispatchTime.distantFuture)


XCTAssertTrue(failureMessage == nil, failureMessage ?? "")
}

func testBreakOutOfSessionRefreshLoopAfterConfiguredMaxRetryAttempts() {
TestUtils.startST()

var failureMessage: String? = nil;
do {
try SuperTokens.initialize(apiDomain: testAPIBase, maxRetryAttemptsForSessionRefresh: 5, tokenTransferMethod: .cookie)
} catch {
failureMessage = "supertokens init failed"
}

let requestSemaphore = DispatchSemaphore(value: 0)

// Step 1: Login request
URLSession.shared.dataTask(with: TestUtils.getLoginRequest(), completionHandler: { data, response, error in
if error != nil {
failureMessage = "login API error"
requestSemaphore.signal()
return
}

if let httpResponse = response as? HTTPURLResponse {
if httpResponse.statusCode != 200 {
failureMessage = "Login response code is not 200";
requestSemaphore.signal()
} else {
let throw401URL = URL(string: "\(testAPIBase)/throw-401")!
var throw401Request = URLRequest(url: throw401URL)
throw401Request.httpMethod = "GET"

URLSession.shared.dataTask(with: throw401Request, completionHandler: { data, response, error in
if let error = error {

if (error as NSError).code != 4 {
failureMessage = "Expected the error code to be 4 (maxRetryAttemptsReachedForSessionRefresh)"
requestSemaphore.signal()
return;
}


let count = TestUtils.getRefreshTokenCounter()
if count != 5 {
failureMessage = "Expected refresh to be called 5 times but it was called " + String(count) + " times"
}
requestSemaphore.signal()
} else {
failureMessage = "Expected /throw-401 request to throw error"
requestSemaphore.signal()
}
}).resume()
}
} else {
failureMessage = "Login response is nil"
requestSemaphore.signal()
}
}).resume()

_ = requestSemaphore.wait(timeout: DispatchTime.distantFuture)


XCTAssertTrue(failureMessage == nil, failureMessage ?? "")
}

func testShouldNotDoSessionRefreshIfMaxRetryAttemptsForSessionRefreshIsZero() {
TestUtils.startST()

var failureMessage: String? = nil;
do {
try SuperTokens.initialize(apiDomain: testAPIBase, maxRetryAttemptsForSessionRefresh: 0, tokenTransferMethod: .cookie)
} catch {
failureMessage = "supertokens init failed"
}

let requestSemaphore = DispatchSemaphore(value: 0)

// Step 1: Login request
URLSession.shared.dataTask(with: TestUtils.getLoginRequest(), completionHandler: { data, response, error in
if error != nil {
failureMessage = "login API error"
requestSemaphore.signal()
return
}

if let httpResponse = response as? HTTPURLResponse {
if httpResponse.statusCode != 200 {
failureMessage = "Login response code is not 200";
requestSemaphore.signal()
} else {
let throw401URL = URL(string: "\(testAPIBase)/throw-401")!
var throw401Request = URLRequest(url: throw401URL)
throw401Request.httpMethod = "GET"

URLSession.shared.dataTask(with: throw401Request, completionHandler: { data, response, error in
if let error = error {

if (error as NSError).code != 4 {
failureMessage = "Expected the error code to be 4 (maxRetryAttemptsReachedForSessionRefresh)"
requestSemaphore.signal()
return;
}


let count = TestUtils.getRefreshTokenCounter()
if count != 0 {
failureMessage = "Expected refresh to be called 0 times but it was called " + String(count) + " times"
}
requestSemaphore.signal()
} else {
failureMessage = "Expected /throw-401 request to throw error"
requestSemaphore.signal()
}
}).resume()
}
} else {
failureMessage = "Login response is nil"
requestSemaphore.signal()
}
}).resume()

_ = requestSemaphore.wait(timeout: DispatchTime.distantFuture)


XCTAssertTrue(failureMessage == nil, failureMessage ?? "")
}

}

0 comments on commit 2fdb97b

Please sign in to comment.