diff --git a/src/main/java/io/supertokens/totp/Totp.java b/src/main/java/io/supertokens/totp/Totp.java index 320d8b513..57f31744e 100644 --- a/src/main/java/io/supertokens/totp/Totp.java +++ b/src/main/java/io/supertokens/totp/Totp.java @@ -187,120 +187,118 @@ private static void checkAndStoreCode(TenantIdentifierWithStorage tenantIdentifi TOTPSQLStorage totpSQLStorage = tenantIdentifierWithStorage.getTOTPStorage(); - while (true) { - try { - totpSQLStorage.startTransaction(con -> { - try { - TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrder_Transaction(con, - tenantIdentifierWithStorage, - userId); - - // N represents # of invalid attempts that will trigger rate limiting: - int N = Config.getConfig(tenantIdentifierWithStorage, main).getTotpMaxAttempts(); // (Default 5) - // Count # of contiguous invalids in latest N attempts (stop at first valid): - long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid) - .count(); - int rateLimitResetTimeInMs = Config.getConfig(tenantIdentifierWithStorage, main) - .getTotpRateLimitCooldownTimeSec() * - 1000; // (Default 15 mins) - - // Check if the user has been rate limited: - if (invalidOutOfN == N) { - // All of the latest N attempts were invalid: - long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; - long now = System.currentTimeMillis(); - - if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { - // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since - // the last invalid code: - long timeLeftMs = (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); - throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs)); - - // If we insert the used code here, then it will further delay the user from - // being able to login. So not inserting it here. - } - } + try { + totpSQLStorage.startTransaction(con -> { + try { + TOTPUsedCode[] usedCodes = totpSQLStorage.getAllUsedCodesDescOrder_Transaction(con, + tenantIdentifierWithStorage, + userId); + + // N represents # of invalid attempts that will trigger rate limiting: + int N = Config.getConfig(tenantIdentifierWithStorage, main).getTotpMaxAttempts(); // (Default 5) + // Count # of contiguous invalids in latest N attempts (stop at first valid): + long invalidOutOfN = Arrays.stream(usedCodes).limit(N).takeWhile(usedCode -> !usedCode.isValid) + .count(); + int rateLimitResetTimeInMs = Config.getConfig(tenantIdentifierWithStorage, main) + .getTotpRateLimitCooldownTimeSec() * + 1000; // (Default 15 mins) + + // Check if the user has been rate limited: + if (invalidOutOfN == N) { + // All of the latest N attempts were invalid: + long latestInvalidCodeCreatedTime = usedCodes[0].createdTime; + long now = System.currentTimeMillis(); - // Check if the code is valid for any device: - boolean isValid = false; - TOTPDevice matchingDevice = null; - for (TOTPDevice device : devices) { - // Check if the code is valid for this device: - if (checkCode(device, code)) { - isValid = true; - matchingDevice = device; - break; - } - } + if (now - latestInvalidCodeCreatedTime < rateLimitResetTimeInMs) { + // Less than rateLimitResetTimeInMs (default = 15 mins) time has elasped since + // the last invalid code: + long timeLeftMs = (rateLimitResetTimeInMs - (now - latestInvalidCodeCreatedTime)); + throw new StorageTransactionLogicException(new LimitReachedException(timeLeftMs, (int)invalidOutOfN, N)); - // Check if the code has been previously used by the user and it was valid (and - // is still valid). If so, this could be a replay attack. So reject it. - if (isValid) { - for (TOTPUsedCode usedCode : usedCodes) { - // One edge case is that if the user has 2 devices, and they are used back to - // back (within 90 seconds) such that the code of the first device was - // regenerated by the second device, then it won't allow the second device's - // code to be used until it is expired. - // But this would be rare so we can ignore it for now. - if (usedCode.code.equals(code) && usedCode.isValid - && usedCode.expiryTime > System.currentTimeMillis()) { - isValid = false; - // We found a matching device but the code - // will be considered invalid here. - } - } + // If we insert the used code here, then it will further delay the user from + // being able to login. So not inserting it here. } + } - // Insert the code into the list of used codes: - - // If device is found, calculate used code expiry time for that device (based on - // its period and skew). Otherwise, use the max used code expiry time of all the - // devices. - int maxUsedCodeExpiry = Arrays.stream(devices) - .mapToInt(device -> device.period * (2 * device.skew + 1)) - .max() - .orElse(0); - int expireInSec = (matchingDevice != null) - ? matchingDevice.period * (2 * matchingDevice.skew + 1) - : maxUsedCodeExpiry; - - long now = System.currentTimeMillis(); - TOTPUsedCode newCode = new TOTPUsedCode(userId, - code, - isValid, now + 1000 * expireInSec, now); - try { - totpSQLStorage.insertUsedCode_Transaction(con, tenantIdentifierWithStorage, newCode); - totpSQLStorage.commitTransaction(con); - } catch (UsedCodeAlreadyExistsException | UnknownTotpUserIdException e) { - throw new StorageTransactionLogicException(e); + // Check if the code is valid for any device: + boolean isValid = false; + TOTPDevice matchingDevice = null; + for (TOTPDevice device : devices) { + // Check if the code is valid for this device: + if (checkCode(device, code)) { + isValid = true; + matchingDevice = device; + break; } + } - if (!isValid) { - // transaction has been committed, so we can directly throw the exception: - throw new StorageTransactionLogicException(new InvalidTotpException()); + // Check if the code has been previously used by the user and it was valid (and + // is still valid). If so, this could be a replay attack. So reject it. + if (isValid) { + for (TOTPUsedCode usedCode : usedCodes) { + // One edge case is that if the user has 2 devices, and they are used back to + // back (within 90 seconds) such that the code of the first device was + // regenerated by the second device, then it won't allow the second device's + // code to be used until it is expired. + // But this would be rare so we can ignore it for now. + if (usedCode.code.equals(code) && usedCode.isValid + && usedCode.expiryTime > System.currentTimeMillis()) { + isValid = false; + // We found a matching device but the code + // will be considered invalid here. + } } + } - return null; - } catch (TenantOrAppNotFoundException e) { + // Insert the code into the list of used codes: + + // If device is found, calculate used code expiry time for that device (based on + // its period and skew). Otherwise, use the max used code expiry time of all the + // devices. + int maxUsedCodeExpiry = Arrays.stream(devices) + .mapToInt(device -> device.period * (2 * device.skew + 1)) + .max() + .orElse(0); + int expireInSec = (matchingDevice != null) + ? matchingDevice.period * (2 * matchingDevice.skew + 1) + : maxUsedCodeExpiry; + + long now = System.currentTimeMillis(); + TOTPUsedCode newCode = new TOTPUsedCode(userId, + code, + isValid, now + 1000L * expireInSec, now); + try { + totpSQLStorage.insertUsedCode_Transaction(con, tenantIdentifierWithStorage, newCode); + totpSQLStorage.commitTransaction(con); + } catch (UnknownTotpUserIdException e) { throw new StorageTransactionLogicException(e); + } catch (UsedCodeAlreadyExistsException e) { + throw new StorageTransactionLogicException(new InvalidTotpException((int) invalidOutOfN, N)); } - }); - return; // exit the while loop - } catch (StorageTransactionLogicException e) { - // throwing errors will also help exit the while loop: - if (e.actualException instanceof TenantOrAppNotFoundException) { - throw (TenantOrAppNotFoundException) e.actualException; - } else if (e.actualException instanceof LimitReachedException) { - throw (LimitReachedException) e.actualException; - } else if (e.actualException instanceof InvalidTotpException) { - throw (InvalidTotpException) e.actualException; - } else if (e.actualException instanceof UnknownTotpUserIdException) { - throw (UnknownTotpUserIdException) e.actualException; - } else if (e.actualException instanceof UsedCodeAlreadyExistsException) { - throw new InvalidTotpException(); - } else { - throw e; + + if (!isValid) { + // transaction has been committed, so we can directly throw the exception: + throw new StorageTransactionLogicException(new InvalidTotpException((int)invalidOutOfN+1, N)); + } + + return null; + } catch (TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); } + }); + return; // exit the while loop + } catch (StorageTransactionLogicException e) { + // throwing errors will also help exit the while loop: + if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; + } else if (e.actualException instanceof LimitReachedException) { + throw (LimitReachedException) e.actualException; + } else if (e.actualException instanceof InvalidTotpException) { + throw (InvalidTotpException) e.actualException; + } else if (e.actualException instanceof UnknownTotpUserIdException) { + throw (UnknownTotpUserIdException) e.actualException; + } else { + throw e; } } } diff --git a/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java index 9dce2f51d..fc6dd25f2 100644 --- a/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java +++ b/src/main/java/io/supertokens/totp/exceptions/InvalidTotpException.java @@ -1,5 +1,12 @@ package io.supertokens.totp.exceptions; public class InvalidTotpException extends Exception { + public int currentAttempts; + public int maxAttempts; + public InvalidTotpException(int currentAttempts, int maxAttempts) { + super("Invalid totp"); + this.currentAttempts = currentAttempts; + this.maxAttempts = maxAttempts; + } } diff --git a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java index b7b1c8078..635aad73d 100644 --- a/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java +++ b/src/main/java/io/supertokens/totp/exceptions/LimitReachedException.java @@ -3,9 +3,13 @@ public class LimitReachedException extends Exception { public long retryAfterMs; + public int currentAttempts; + public int maxAttempts; - public LimitReachedException(long retryAfterMs) { + public LimitReachedException(long retryAfterMs, int currentAttempts, int maxAttempts) { super("Retry in " + retryAfterMs + " ms"); this.retryAfterMs = retryAfterMs; + this.currentAttempts = currentAttempts; + this.maxAttempts = maxAttempts; } } diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java index b413332df..15e8db621 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpAPI.java @@ -18,6 +18,7 @@ import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; import io.supertokens.useridmapping.UserIdType; +import io.supertokens.utils.SemVer; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -78,12 +79,20 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I } catch (InvalidTotpException e) { result.addProperty("status", "INVALID_TOTP_ERROR"); super.sendJsonResponse(200, result, resp); + if (getVersionFromRequest(req).greaterThanOrEqualTo(SemVer.v4_1)) { + result.addProperty("currentNumberOfFailedAttempts", e.currentAttempts); + result.addProperty("maxNumberOfFailedAttempts", e.maxAttempts); + } } catch (UnknownTotpUserIdException e) { result.addProperty("status", "UNKNOWN_USER_ID_ERROR"); super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); result.addProperty("retryAfterMs", e.retryAfterMs); + if (getVersionFromRequest(req).greaterThanOrEqualTo(SemVer.v4_1)) { + result.addProperty("currentNumberOfFailedAttempts", e.currentAttempts); + result.addProperty("maxNumberOfFailedAttempts", e.maxAttempts); + } super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException | FeatureNotEnabledException | TenantOrAppNotFoundException e) { diff --git a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java index da8c21c01..7a562d51b 100644 --- a/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java +++ b/src/main/java/io/supertokens/webserver/api/totp/VerifyTotpDeviceAPI.java @@ -17,6 +17,7 @@ import io.supertokens.totp.exceptions.InvalidTotpException; import io.supertokens.totp.exceptions.LimitReachedException; import io.supertokens.useridmapping.UserIdType; +import io.supertokens.utils.SemVer; import io.supertokens.webserver.InputParser; import io.supertokens.webserver.WebserverAPI; import jakarta.servlet.ServletException; @@ -83,10 +84,19 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I super.sendJsonResponse(200, result, resp); } catch (InvalidTotpException e) { result.addProperty("status", "INVALID_TOTP_ERROR"); + + if (getVersionFromRequest(req).greaterThanOrEqualTo(SemVer.v4_1)) { + result.addProperty("currentNumberOfFailedAttempts", e.currentAttempts); + result.addProperty("maxNumberOfFailedAttempts", e.maxAttempts); + } super.sendJsonResponse(200, result, resp); } catch (LimitReachedException e) { result.addProperty("status", "LIMIT_REACHED_ERROR"); result.addProperty("retryAfterMs", e.retryAfterMs); + if (getVersionFromRequest(req).greaterThanOrEqualTo(SemVer.v4_1)) { + result.addProperty("currentNumberOfFailedAttempts", e.currentAttempts); + result.addProperty("maxNumberOfFailedAttempts", e.maxAttempts); + } super.sendJsonResponse(200, result, resp); } catch (StorageQueryException | StorageTransactionLogicException | TenantOrAppNotFoundException e) { throw new ServletException(e); diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index f926d2267..4afc4d279 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -55,8 +55,7 @@ import java.time.Instant; import java.util.Objects; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; +import static org.junit.Assert.*; // TODO: Add test for UsedCodeAlreadyExistsException once we implement time mocking @@ -213,15 +212,18 @@ public void createDeviceAndVerifyCodeTest() throws Exception { () -> Totp.verifyCode(main, "user", newValidCode)); // Use a code from next period: + Thread.sleep(1); String nextValidCode = generateTotpCode(main, device, 1); Totp.verifyCode(main, "user", nextValidCode); // Use previous period code (should fail coz validCode has been used): + Thread.sleep(1); String previousCode = generateTotpCode(main, device, -1); assert previousCode.equals(validCode); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", previousCode)); // Create device with skew = 0, check that it only works with the current code + Thread.sleep(1); TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 0, 1); assert !Objects.equals(device2.secretKey, device.secretKey); Totp.verifyDevice(main, "user", device2.deviceName, generateTotpCode(main, device2)); @@ -234,14 +236,17 @@ public void createDeviceAndVerifyCodeTest() throws Exception { () -> Totp.verifyCode(main, "user", nextValidCode2)); String previousValidCode2 = generateTotpCode(main, device2, -1); + Thread.sleep(1); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", previousValidCode2)); + Thread.sleep(1); String currentValidCode2 = generateTotpCode(main, device2); Totp.verifyCode(main, "user", currentValidCode2); // Submit invalid code and check that it's expiry time is correct // created - expiryTime = max of ((2 * skew + 1) * period) for all devices + Thread.sleep(1); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "invalid")); @@ -288,6 +293,7 @@ public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Excepti // This is to trigger rate limiting for (int i = 0; i < N; i++) { String code = "ic-" + i; // ic = invalid code + Thread.sleep(1); assertThrows( InvalidTotpException.class, () -> Totp.verifyCode(main, "user", code)); @@ -295,12 +301,15 @@ public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Excepti // Any kind of attempt after this should fail with rate limiting error. // This should happen until rate limiting cooldown happens: + Thread.sleep(1); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", "icN+1")); + Thread.sleep(1); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", generateTotpCode(main, device))); + Thread.sleep(1); assertThrows( LimitReachedException.class, () -> Totp.verifyCode(main, "user", "icN+2")); @@ -402,7 +411,9 @@ public void createAndVerifyDeviceTest() throws Exception { // Verify device with wrong code assertThrows(InvalidTotpException.class, () -> Totp.verifyDevice(main, "user", "deviceName", "ic0")); + // Verify device with correct code + Thread.sleep(1); String validCode = generateTotpCode(main, device); boolean justVerfied = Totp.verifyDevice(main, "user", "deviceName", validCode); assert justVerfied; @@ -438,7 +449,9 @@ public void removeDeviceTest() throws Exception { TOTPDevice device1 = Totp.registerDevice(main, "user", "device1", 1, 30); TOTPDevice device2 = Totp.registerDevice(main, "user", "device2", 1, 30); + Thread.sleep(1); Totp.verifyDevice(main, "user", "device1", generateTotpCode(main, device1, -1)); + Thread.sleep(1); Totp.verifyDevice(main, "user", "device2", generateTotpCode(main, device2, -1)); TOTPDevice[] devices = Totp.getDevices(main, "user"); @@ -456,7 +469,9 @@ public void removeDeviceTest() throws Exception { Thread.sleep(1000 - System.currentTimeMillis() % 1000 + 10); + Thread.sleep(1); Totp.verifyCode(main, "user", generateTotpCode(main, device1)); + Thread.sleep(1); Totp.verifyCode(main, "user", generateTotpCode(main, device2)); // Delete device1 @@ -477,6 +492,7 @@ public void removeDeviceTest() throws Exception { // Create another user to test that other users aren't affected: TOTPDevice otherUserDevice = Totp.registerDevice(main, "other-user", "device", 1, 30); Totp.verifyDevice(main, "other-user", "device", generateTotpCode(main, otherUserDevice, -1)); + Thread.sleep(1); Totp.verifyCode(main, "other-user", generateTotpCode(main, otherUserDevice)); assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "other-user", "ic1")); @@ -580,4 +596,72 @@ public void testRegisterDeviceWithSameNameAsAnUnverifiedDevice() throws Exceptio Totp.registerDevice(main, "user", "device1", 1, 30); Totp.registerDevice(main, "user", "device1", 1, 30); } + + @Test + public void testCurrentAndMaxAttemptsInExceptions() throws Exception { + String[] args = {"../"}; + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { + return; + } + + TOTPDevice device = Totp.registerDevice(process.getProcess(), "userId", "deviceName", 1, 30); + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "123456"); + fail(); + } catch (InvalidTotpException e) { + assertEquals(1, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + Thread.sleep(1); + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "223456"); + fail(); + } catch (InvalidTotpException e) { + assertEquals(2, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + Thread.sleep(1); + + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "323456"); + fail(); + } catch (InvalidTotpException e) { + assertEquals(3, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + Thread.sleep(1); + + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "423456"); + fail(); + } catch (InvalidTotpException e) { + assertEquals(4, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + Thread.sleep(1); + + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "523456"); + fail(); + } catch (InvalidTotpException e) { + assertEquals(5, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + Thread.sleep(1); + + try { + Totp.verifyDevice(process.getProcess(), "userId", "deviceName", "623456"); + fail(); + } catch (LimitReachedException e) { + assertEquals(5, e.currentAttempts); + assertEquals(5, e.maxAttempts); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } } diff --git a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java index 889eb02f6..9dfda6763 100644 --- a/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java +++ b/src/test/java/io/supertokens/test/totp/TotpLicenseTest.java @@ -171,6 +171,7 @@ public void testTotpWithLicense() throws Exception { String code = generateTotpCode(main, device, 0); Totp.verifyDevice(main, device.userId, device.deviceName, code); // Verify code + Thread.sleep(1); String nextCode = generateTotpCode(main, device, 1); Totp.verifyCode(main, "user", nextCode); } diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java index ffc556d24..31018ba57 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpAPITest.java @@ -191,7 +191,7 @@ public void testApi() throws Exception { assert res3.get("retryAfterMs") != null; // wait for cooldown to end (1s) - Thread.sleep(1200); + Thread.sleep(1300); // should pass now on valid code String validTotp = generateTotpCode(process.getProcess(), device); diff --git a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java index 82c18b883..8a55255c9 100644 --- a/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java +++ b/src/test/java/io/supertokens/test/totp/api/VerifyTotpDeviceAPITest.java @@ -162,7 +162,10 @@ public void testApi() throws Exception { null, Utils.getCdiVersionStringLatestForTests(), "totp"); + assertEquals(3, res0.entrySet().size()); assert res0.get("status").getAsString().equals("INVALID_TOTP_ERROR"); + assertEquals(1, res0.get("currentNumberOfFailedAttempts").getAsInt()); + assertEquals(1, res0.get("maxNumberOfFailedAttempts").getAsInt()); // Check that rate limiting is triggered for the user: JsonObject res3 = HttpRequestForTesting.sendJsonPOSTRequest(