diff --git a/.circleci/config.yml b/.circleci/config.yml index f7b3ad0..0e91623 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -69,6 +69,73 @@ jobs: name: running tests command: (cd .circleci/ && ./doTests.sh) - slack/status + test-onemillionusers: + docker: + - image: rishabhpoddar/supertokens_core_testing + resource_class: large + steps: + - checkout + - run: echo $'\n[mysqld]\ncharacter_set_server=utf8mb4\nmax_connections=10000' >> /etc/mysql/mysql.cnf + - run: + name: starting mysql + command: | + (cd / && ./runMySQL.sh) + mysql -u root --password=root -e "CREATE DATABASE supertokens;" + mysql -u root --password=root -e "CREATE DATABASE st0;" + mysql -u root --password=root -e "CREATE DATABASE st1;" + mysql -u root --password=root -e "CREATE DATABASE st2;" + mysql -u root --password=root -e "CREATE DATABASE st3;" + mysql -u root --password=root -e "CREATE DATABASE st4;" + mysql -u root --password=root -e "CREATE DATABASE st5;" + mysql -u root --password=root -e "CREATE DATABASE st6;" + mysql -u root --password=root -e "CREATE DATABASE st7;" + mysql -u root --password=root -e "CREATE DATABASE st8;" + mysql -u root --password=root -e "CREATE DATABASE st9;" + mysql -u root --password=root -e "CREATE DATABASE st10;" + mysql -u root --password=root -e "CREATE DATABASE st11;" + mysql -u root --password=root -e "CREATE DATABASE st12;" + mysql -u root --password=root -e "CREATE DATABASE st13;" + mysql -u root --password=root -e "CREATE DATABASE st14;" + mysql -u root --password=root -e "CREATE DATABASE st15;" + mysql -u root --password=root -e "CREATE DATABASE st16;" + mysql -u root --password=root -e "CREATE DATABASE st17;" + mysql -u root --password=root -e "CREATE DATABASE st18;" + mysql -u root --password=root -e "CREATE DATABASE st19;" + mysql -u root --password=root -e "CREATE DATABASE st20;" + mysql -u root --password=root -e "CREATE DATABASE st21;" + mysql -u root --password=root -e "CREATE DATABASE st22;" + mysql -u root --password=root -e "CREATE DATABASE st23;" + mysql -u root --password=root -e "CREATE DATABASE st24;" + mysql -u root --password=root -e "CREATE DATABASE st25;" + mysql -u root --password=root -e "CREATE DATABASE st26;" + mysql -u root --password=root -e "CREATE DATABASE st27;" + mysql -u root --password=root -e "CREATE DATABASE st28;" + mysql -u root --password=root -e "CREATE DATABASE st29;" + mysql -u root --password=root -e "CREATE DATABASE st30;" + mysql -u root --password=root -e "CREATE DATABASE st31;" + mysql -u root --password=root -e "CREATE DATABASE st32;" + mysql -u root --password=root -e "CREATE DATABASE st33;" + mysql -u root --password=root -e "CREATE DATABASE st34;" + mysql -u root --password=root -e "CREATE DATABASE st35;" + mysql -u root --password=root -e "CREATE DATABASE st36;" + mysql -u root --password=root -e "CREATE DATABASE st37;" + mysql -u root --password=root -e "CREATE DATABASE st38;" + mysql -u root --password=root -e "CREATE DATABASE st39;" + mysql -u root --password=root -e "CREATE DATABASE st40;" + mysql -u root --password=root -e "CREATE DATABASE st41;" + mysql -u root --password=root -e "CREATE DATABASE st42;" + mysql -u root --password=root -e "CREATE DATABASE st43;" + mysql -u root --password=root -e "CREATE DATABASE st44;" + mysql -u root --password=root -e "CREATE DATABASE st45;" + mysql -u root --password=root -e "CREATE DATABASE st46;" + mysql -u root --password=root -e "CREATE DATABASE st47;" + mysql -u root --password=root -e "CREATE DATABASE st48;" + mysql -u root --password=root -e "CREATE DATABASE st49;" + mysql -u root --password=root -e "CREATE DATABASE st50;" + - run: + name: running tests + command: (cd .circleci/ && ./doOneMillionUsersTests.sh) + - slack/status workflows: version: 2 @@ -81,4 +148,23 @@ workflows: tags: only: /dev-v[0-9]+(\.[0-9]+)*/ branches: - ignore: /.*/ \ No newline at end of file + ignore: /.*/ + - test-onemillionusers: + context: + - slack-notification + filters: + tags: + only: /dev-v[0-9]+(\.[0-9]+)*/ + branches: + ignore: /.*/ + - mark-passed: + context: + - slack-notification + filters: + tags: + only: /dev-v[0-9]+(\.[0-9]+)*/ + branches: + ignore: /.*/ + requires: + - test + - test-onemillionusers diff --git a/.circleci/doOneMillionUsersTests.sh b/.circleci/doOneMillionUsersTests.sh new file mode 100755 index 0000000..112e76f --- /dev/null +++ b/.circleci/doOneMillionUsersTests.sh @@ -0,0 +1,135 @@ +function cleanup { + if test -f "pluginInterfaceExactVersionsOutput"; then + rm pluginInterfaceExactVersionsOutput + fi +} + +trap cleanup EXIT +cleanup + +pluginInterfaceJson=`cat ../pluginInterfaceSupported.json` +pluginInterfaceLength=`echo $pluginInterfaceJson | jq ".versions | length"` +pluginInterfaceArray=`echo $pluginInterfaceJson | jq ".versions"` +echo "got plugin interface relations" + +./getPluginInterfaceExactVersions.sh $pluginInterfaceLength "$pluginInterfaceArray" + +if [[ $? -ne 0 ]] +then + echo "all plugin interfaces found... failed. exiting!" + exit 1 +else + echo "all plugin interfaces found..." +fi + +# get plugin version +pluginVersion=`cat ../build.gradle | grep -e "version =" -e "version="` +while IFS='"' read -ra ADDR; do + counter=0 + for i in "${ADDR[@]}"; do + if [ $counter == 1 ] + then + pluginVersion=$i + fi + counter=$(($counter+1)) + done +done <<< "$pluginVersion" + +responseStatus=`curl -s -o /dev/null -w "%{http_code}" -X PUT \ + https://api.supertokens.io/0/plugin \ + -H 'Content-Type: application/json' \ + -H 'api-version: 0' \ + -d "{ + \"password\": \"$SUPERTOKENS_API_KEY\", + \"planType\":\"FREE\", + \"version\":\"$pluginVersion\", + \"pluginInterfaces\": $pluginInterfaceArray, + \"name\": \"mysql\" +}"` +if [ $responseStatus -ne "200" ] +then + echo "failed plugin PUT API status code: $responseStatus. Exiting!" + exit 1 +fi + +someTestsRan=false +while read -u 10 line +do + if [[ $line = "" ]]; then + continue + fi + i=0 + currTag=`echo $line | jq .tag` + currTag=`echo $currTag | tr -d '"'` + + currVersion=`echo $line | jq .version` + currVersion=`echo $currVersion | tr -d '"'` + piX=$(cut -d'.' -f1 <<<"$currVersion") + piY=$(cut -d'.' -f2 <<<"$currVersion") + piVersion="$piX.$piY" + + someTestsRan=true + + response=`curl -s -X GET \ + "https://api.supertokens.io/0/plugin-interface/dependency/core/latest?password=$SUPERTOKENS_API_KEY&planType=FREE&mode=DEV&version=$piVersion" \ + -H 'api-version: 0'` + if [[ `echo $response | jq .core` == "null" ]] + then + echo "fetching latest X.Y version for core given plugin-interface X.Y version: $piVersion gave response: $response" + exit 1 + fi + coreVersionX2=$(echo $response | jq .core | tr -d '"') + + response=`curl -s -X GET \ + "https://api.supertokens.io/0/core/latest?password=$SUPERTOKENS_API_KEY&planType=FREE&mode=DEV&version=$coreVersionX2" \ + -H 'api-version: 0'` + if [[ `echo $response | jq .tag` == "null" ]] + then + echo "fetching latest X.Y.Z version for core X.Y version: $coreVersionX2 gave response: $response" + exit 1 + fi + coreVersionTag=$(echo $response | jq .tag | tr -d '"') + + cd ../../ + git clone git@github.com:supertokens/supertokens-root.git + cd supertokens-root + + update-alternatives --install "/usr/bin/java" "java" "/usr/java/jdk-15.0.1/bin/java" 2 + update-alternatives --install "/usr/bin/javac" "javac" "/usr/java/jdk-15.0.1/bin/javac" 2 + + pluginX=$(cut -d'.' -f1 <<<"$pluginVersion") + pluginY=$(cut -d'.' -f2 <<<"$pluginVersion") + echo -e "core,$coreVersionX2\nplugin-interface,$piVersion\nmysql-plugin,$pluginX.$pluginY" > modules.txt + ./loadModules + cd supertokens-core + git checkout $coreVersionTag + cd ../supertokens-plugin-interface + git checkout $currTag + cd ../supertokens-mysql-plugin + git checkout dev-v$pluginVersion + cd ../ + echo $SUPERTOKENS_API_KEY > apiPassword + export ONE_MILLION_USERS_TEST=1 + ./utils/setupTestEnv --cicd + ./gradlew :supertokens-postgresql-plugin:test --tests io.supertokens.storage.mysql.test.OneMillionUsersTest + + if [[ $? -ne 0 ]] + then + cat logs/* + cd ../project/ + echo "test failed... exiting!" + exit 1 + fi + cd ../ + rm -rf supertokens-root + cd project/.circleci +done 10 { try { long now = System.currentTimeMillis(); + Connection sqlCon = (Connection) con.getConnection(); + TOTPQueries.createDevice_Transaction(this, sqlCon, tenantIdentifier.toAppIdentifier(), device); TOTPQueries.insertUsedCode_Transaction(this, - (Connection) con.getConnection(), tenantIdentifier, new TOTPUsedCode(userId, "123456", true, 1000+now, now)); + sqlCon, tenantIdentifier, new TOTPUsedCode(userId, "123456", true, 1000+now, now)); } catch (SQLException e) { throw new StorageTransactionLogicException(e); } @@ -1288,26 +1289,7 @@ public int countUsersActiveSince(AppIdentifier appIdentifier, long time) throws } } - @Override - public int countUsersEnabledTotp(AppIdentifier appIdentifier) throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotp(this, appIdentifier); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public int countUsersEnabledTotpAndActiveSince(AppIdentifier appIdentifier, long time) - throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, appIdentifier, time); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override + @Override public void deleteUserActive_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { try { @@ -2193,10 +2175,11 @@ public boolean updateOrDeleteExternalUserIdInfo(AppIdentifier appIdentifier, Str } @Override - public HashMap getUserIdMappingForSuperTokensIds(ArrayList userIds) + public HashMap getUserIdMappingForSuperTokensIds(AppIdentifier appIdentifier, + ArrayList userIds) throws StorageQueryException { try { - return UserIdMappingQueries.getUserIdMappingWithUserIds(this, userIds); + return UserIdMappingQueries.getUserIdMappingWithUserIds(this, appIdentifier, userIds); } catch (SQLException e) { throw new StorageQueryException(e); } @@ -2562,24 +2545,59 @@ public void revokeExpiredSessions() throws StorageQueryException { } // TOTP recipe: + @TestOnly @Override public void createDevice(AppIdentifier appIdentifier, TOTPDevice device) throws StorageQueryException, DeviceAlreadyExistsException, TenantOrAppNotFoundException { try { - TOTPQueries.createDevice(this, appIdentifier, device); + startTransaction(con -> { + try { + createDevice_Transaction(con, new AppIdentifier(null, null), device); + } catch (DeviceAlreadyExistsException | TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); + } + return null; + }); } catch (StorageTransactionLogicException e) { - if (e.actualException instanceof SQLIntegrityConstraintViolationException) { - String errMsg = e.actualException.getMessage(); + if (e.actualException instanceof DeviceAlreadyExistsException) { + throw (DeviceAlreadyExistsException) e.actualException; + } else if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; + } else if (e.actualException instanceof StorageQueryException) { + throw (StorageQueryException) e.actualException; + } + } + } + + @Override + public TOTPDevice createDevice_Transaction(TransactionConnection con, AppIdentifier appIdentifier, TOTPDevice device) + throws DeviceAlreadyExistsException, TenantOrAppNotFoundException, StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.createDevice_Transaction(this, sqlCon, appIdentifier, device); + return device; + } catch (SQLException actualException) { + if (actualException instanceof SQLIntegrityConstraintViolationException) { + String errMsg = actualException.getMessage(); if (isPrimaryKeyError(errMsg, Config.getConfig(this).getTotpUserDevicesTable())) { throw new DeviceAlreadyExistsException(); } else if (isForeignKeyConstraintError(errMsg, Config.getConfig(this).getTotpUsersTable(), "app_id")) { throw new TenantOrAppNotFoundException(appIdentifier); } - } - throw new StorageQueryException(e.actualException); + throw new StorageQueryException(actualException); + } + } + + @Override + public TOTPDevice getDeviceByName_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId, String deviceName) throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getDeviceByName_Transaction(this, sqlCon, appIdentifier, userId, deviceName); + } catch (SQLException e) { + throw new StorageQueryException(e); } } @@ -2675,7 +2693,7 @@ public TOTPDevice[] getDevices_Transaction(TransactionConnection con, AppIdentif @Override public void insertUsedCode_Transaction(TransactionConnection con, TenantIdentifier tenantIdentifier, TOTPUsedCode usedCodeObj) - throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException, + throws StorageQueryException, UnknownTotpUserIdException, UsedCodeAlreadyExistsException, TenantOrAppNotFoundException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -2685,7 +2703,7 @@ public void insertUsedCode_Transaction(TransactionConnection con, TenantIdentifi throw new UsedCodeAlreadyExistsException(); } else if (isForeignKeyConstraintError(e.getMessage(), Config.getConfig(this).getTotpUsedCodesTable(), "user_id")) { - throw new TotpNotEnabledException(); + throw new UnknownTotpUserIdException(); } else if (isForeignKeyConstraintError(e.getMessage(), Config.getConfig(this).getTotpUsedCodesTable(), "tenant_id")) { throw new TenantOrAppNotFoundException(tenantIdentifier); } @@ -2933,6 +2951,24 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A } } + @Override + public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException { + try { + return GeneralQueries.getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long sinceTime) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(this, appIdentifier, sinceTime); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + public static boolean isEnabledForDeadlockTesting() { return enableForDeadlockTesting; } diff --git a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java index 133dc0b..52bc3fe 100644 --- a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java +++ b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java @@ -184,6 +184,14 @@ public String getTenantConfigsTable() { return addPrefixToTableName("tenant_configs"); } + public String getTenantFirstFactorsTable() { + return addPrefixToTableName("tenant_first_factors"); + } + + public String getTenantRequiredSecondaryFactorsTable() { + return addPrefixToTableName("tenant_required_secondary_factors"); + } + public String getTenantThirdPartyProvidersTable() { return addPrefixToTableName("tenant_thirdparty_providers"); } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java index 289980a..1905ec2 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java @@ -137,4 +137,41 @@ public static void deleteUserActive_Transaction(Connection con, Start start, App pst.setString(2, userId); }); } + + public static int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime) + throws SQLException, StorageQueryException { + // TODO: Active users are present only on public tenant and MFA users may be present on different storages + String QUERY = + "SELECT COUNT(user_id) as c FROM (" + + " " // users with more than one login method + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + Config.getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? AND primary_or_recipe_user_id IN (" + + " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " GROUP BY app_id, primary_or_recipe_user_id" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " UNION" // TOTP users + + " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE app_id = ? AND user_id IN (" + + " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " " + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + pst.setLong(3, sinceTime); + pst.setString(4, appIdentifier.getAppId()); + pst.setString(5, appIdentifier.getAppId()); + pst.setLong(6, sinceTime); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/EmailVerificationQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/EmailVerificationQueries.java index eee5a00..5941063 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/EmailVerificationQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/EmailVerificationQueries.java @@ -284,7 +284,7 @@ public static List isEmailVerified_transaction(Start start, Connection s // calculating the verified emails HashMap supertokensUserIdToExternalUserIdMap = UserIdMappingQueries.getUserIdMappingWithUserIds_Transaction(start, - sqlCon, supertokensUserIds); + sqlCon, appIdentifier, supertokensUserIds); HashMap externalUserIdToSupertokensUserIdMap = new HashMap<>(); List supertokensOrExternalUserIdsToQuery = new ArrayList<>(); @@ -352,7 +352,7 @@ public static List isEmailVerified(Start start, AppIdentifier appIdentif // We have external user id stored in the email verification table, so we need to fetch the mapped userids for // calculating the verified emails HashMap supertokensUserIdToExternalUserIdMap = UserIdMappingQueries.getUserIdMappingWithUserIds(start, - supertokensUserIds); + appIdentifier, supertokensUserIds); HashMap externalUserIdToSupertokensUserIdMap = new HashMap<>(); List supertokensOrExternalUserIdsToQuery = new ArrayList<>(); for (String userId : supertokensUserIds) { diff --git a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java index ba88f54..8e16b12 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java @@ -263,6 +263,16 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto update(start, MultitenancyQueries.getQueryToCreateTenantConfigsTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getTenantFirstFactorsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTenantRequiredSecondaryFactorsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateRequiredSecondaryFactorsTable(start), NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProvidersTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProvidersTable(start), @@ -1561,6 +1571,32 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti }); } + public static int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(Start start, AppIdentifier appIdentifier) + throws SQLException, StorageQueryException { + String QUERY = + "SELECT COUNT(user_id) as c FROM (" + + " " // Users with number of login methods > 1 + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + Config.getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? " + + " GROUP BY app_id, primary_or_recipe_user_id" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " UNION" // TOTP users + + " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE app_id = ?" + + " " + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } + public static boolean checkIfUsesAccountLinking(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT 1 FROM " diff --git a/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java index edacff4..53923d1 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java @@ -27,6 +27,7 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.storage.mysql.Start; import io.supertokens.storage.mysql.config.Config; +import io.supertokens.storage.mysql.queries.multitenancy.MfaSqlHelper; import io.supertokens.storage.mysql.queries.multitenancy.TenantConfigSQLHelper; import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderClientSQLHelper; import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderSQLHelper; @@ -107,8 +108,38 @@ static String getQueryToCreateTenantThirdPartyProviderClientsTable(Start start) + ");"; } + public static String getQueryToCreateFirstFactorsTable(Start start) { + String tableName = Config.getConfig(start).getTenantFirstFactorsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateRequiredSecondaryFactorsTable(Start start) { + String tableName = Config.getConfig(start).getTenantRequiredSecondaryFactorsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE);"; + // @formatter:on + } + private static void executeCreateTenantQueries(Start start, Connection sqlCon, TenantConfig tenantConfig) - throws SQLException, StorageTransactionLogicException { + throws SQLException, StorageTransactionLogicException, StorageQueryException { try { TenantConfigSQLHelper.create(start, sqlCon, tenantConfig); @@ -143,6 +174,9 @@ private static void executeCreateTenantQueries(Start start, Connection sqlCon, T } } } + + MfaSqlHelper.createFirstFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.firstFactors); + MfaSqlHelper.createRequiredSecondaryFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.requiredSecondaryFactors); } public static void createTenantConfig(Start start, TenantConfig tenantConfig) throws StorageQueryException, StorageTransactionLogicException { @@ -221,7 +255,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep // Map (tenantIdentifier) -> thirdPartyId -> provider HashMap> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap); - return TenantConfigSQLHelper.selectAll(start, providerMap); + // Map (tenantIdentifier) -> firstFactors + HashMap firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start); + + // Map (tenantIdentifier) -> requiredSecondaryFactors + HashMap requiredSecondaryFactorsMap = MfaSqlHelper.selectAllRequiredSecondaryFactors(start); + + return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, requiredSecondaryFactorsMap); } catch (SQLException throwables) { throw new StorageQueryException(throwables); } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/SessionQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/SessionQueries.java index ae8090b..12e3cca 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/SessionQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/SessionQueries.java @@ -141,18 +141,19 @@ public static SessionInfo getSessionInfo_Transaction(Start start, Connection con public static void updateSessionInfo_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, String sessionHandle, - String refreshTokenHash2, long expiry) + String refreshTokenHash2, long expiry, boolean useStaticKey) throws SQLException, StorageQueryException { String QUERY = "UPDATE " + Config.getConfig(start).getSessionInfoTable() - + " SET refresh_token_hash_2 = ?, expires_at = ?" + + " SET refresh_token_hash_2 = ?, expires_at = ?, use_static_key= ?" + " WHERE app_id = ? AND tenant_id = ? AND session_handle = ?"; update(con, QUERY, pst -> { pst.setString(1, refreshTokenHash2); pst.setLong(2, expiry); - pst.setString(3, tenantIdentifier.getAppId()); - pst.setString(4, tenantIdentifier.getTenantId()); - pst.setString(5, sessionHandle); + pst.setBoolean(3, useStaticKey); + pst.setString(4, tenantIdentifier.getAppId()); + pst.setString(5, tenantIdentifier.getTenantId()); + pst.setString(6, sessionHandle); }); } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java index dadbf0d..b3536dd 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java @@ -39,6 +39,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," + + "created_at BIGINT UNSIGNED NOT NULL," + "PRIMARY KEY (app_id, user_id, device_name)," + "FOREIGN KEY (app_id, user_id)" + " REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(app_id, user_id) ON DELETE CASCADE" @@ -88,7 +89,7 @@ private static int insertUser_Transaction(Start start, Connection con, AppIdenti private static int insertDevice_Transaction(Start start, Connection con, AppIdentifier appIdentifier, TOTPDevice device) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() - + " (app_id, user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?, ?)"; + + " (app_id, user_id, device_name, secret_key, period, skew, verified, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; return update(con, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -98,25 +99,31 @@ private static int insertDevice_Transaction(Start start, Connection con, AppIden pst.setInt(5, device.period); pst.setInt(6, device.skew); pst.setBoolean(7, device.verified); + pst.setLong(8, device.createdAt); }); } - public static void createDevice(Start start, AppIdentifier appIdentifier, TOTPDevice device) - throws StorageQueryException, StorageTransactionLogicException { - start.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - - try { - insertUser_Transaction(start, sqlCon, appIdentifier, device.userId); - insertDevice_Transaction(start, sqlCon, appIdentifier, device); - sqlCon.commit(); - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } + public static void createDevice_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, TOTPDevice device) + throws SQLException, StorageQueryException { + insertUser_Transaction(start, sqlCon, appIdentifier, device.userId); + insertDevice_Transaction(start, sqlCon, appIdentifier, device); + } + + public static TOTPDevice getDeviceByName_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String userId, String deviceName) + throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE app_id = ? AND user_id = ? AND device_name = ? FOR UPDATE;"; + return execute(sqlCon, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, userId); + pst.setString(3, deviceName); + }, result -> { + if (result.next()) { + return TOTPDeviceRowMapper.getInstance().map(result); + } return null; }); - return; } public static int markDeviceAsVerified(Start start, AppIdentifier appIdentifier, String userId, String deviceName) @@ -288,7 +295,8 @@ public TOTPDevice map(ResultSet result) throws SQLException { result.getString("secret_key"), result.getInt("period"), result.getInt("skew"), - result.getBoolean("verified")); + result.getBoolean("verified"), + result.getLong("created_at")); } } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/UserIdMappingQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/UserIdMappingQueries.java index 973c030..355d29e 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/UserIdMappingQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/UserIdMappingQueries.java @@ -117,7 +117,8 @@ public static UserIdMapping[] getUserIdMappingWithEitherSuperTokensUserIdOrExter } - public static HashMap getUserIdMappingWithUserIds(Start start, List userIds) + public static HashMap getUserIdMappingWithUserIds(Start start, + AppIdentifier appIdentifier, List userIds) throws SQLException, StorageQueryException { if (userIds.size() == 0) { @@ -126,7 +127,8 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L // No need to filter based on tenantId because the id list is already filtered for a tenant StringBuilder QUERY = new StringBuilder( - "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE supertokens_user_id IN ("); + "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE app_id = ? AND " + + "supertokens_user_id IN ("); for (int i = 0; i < userIds.size(); i++) { QUERY.append("?"); if (i != userIds.size() - 1) { @@ -136,9 +138,10 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L } QUERY.append(")"); return execute(start, QUERY.toString(), pst -> { + pst.setString(1, appIdentifier.getAppId()); for (int i = 0; i < userIds.size(); i++) { - // i+1 cause this starts with 1 and not 0 - pst.setString(i + 1, userIds.get(i)); + // i+2 cause this starts with 1 and not 0, 1 is appId + pst.setString(i + 2, userIds.get(i)); } }, result -> { HashMap userIdMappings = new HashMap<>(); @@ -150,7 +153,9 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L }); } - public static HashMap getUserIdMappingWithUserIds_Transaction(Start start, Connection sqlCon, List userIds) + public static HashMap getUserIdMappingWithUserIds_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, + List userIds) throws SQLException, StorageQueryException { if (userIds.size() == 0) { @@ -159,7 +164,8 @@ public static HashMap getUserIdMappingWithUserIds_Transaction(St // No need to filter based on tenantId because the id list is already filtered for a tenant StringBuilder QUERY = new StringBuilder( - "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE supertokens_user_id IN ("); + "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE app_id = ? AND " + + "supertokens_user_id IN ("); for (int i = 0; i < userIds.size(); i++) { QUERY.append("?"); if (i != userIds.size() - 1) { @@ -169,9 +175,10 @@ public static HashMap getUserIdMappingWithUserIds_Transaction(St } QUERY.append(")"); return execute(sqlCon, QUERY.toString(), pst -> { + pst.setString(1, appIdentifier.getAppId()); for (int i = 0; i < userIds.size(); i++) { - // i+1 cause this starts with 1 and not 0 - pst.setString(i + 1, userIds.get(i)); + // i+2 cause this starts with 1 and not 0, 1 is appId + pst.setString(i + 2, userIds.get(i)); } }, result -> { HashMap userIdMappings = new HashMap<>(); diff --git a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java new file mode 100644 index 0000000..9da3eb5 --- /dev/null +++ b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.mysql.queries.multitenancy; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.storage.mysql.Start; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.*; + +import static io.supertokens.storage.mysql.QueryExecutorTemplate.execute; +import static io.supertokens.storage.mysql.QueryExecutorTemplate.update; +import static io.supertokens.storage.mysql.config.Config.getConfig; + +public class MfaSqlHelper { + public static HashMap selectAllFirstFactors(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM " + + getConfig(start).getTenantFirstFactorsTable() + ";"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> firstFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), result.getString("app_id"), result.getString("tenant_id")); + if (!firstFactors.containsKey(tenantIdentifier)) { + firstFactors.put(tenantIdentifier, new ArrayList<>()); + } + + firstFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : firstFactors.keySet()) { + finalResult.put(tenantIdentifier, firstFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static HashMap selectAllRequiredSecondaryFactors(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM " + + getConfig(start).getTenantRequiredSecondaryFactorsTable() + ";"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> defaultRequiredFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), + result.getString("app_id"), result.getString("tenant_id")); + if (!defaultRequiredFactors.containsKey(tenantIdentifier)) { + defaultRequiredFactors.put(tenantIdentifier, new ArrayList<>()); + } + + defaultRequiredFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : defaultRequiredFactors.keySet()) { + finalResult.put(tenantIdentifier, defaultRequiredFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static void createFirstFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] firstFactors) + throws SQLException, StorageQueryException { + if (firstFactors == null || firstFactors.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantFirstFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);"; + for (String factorId : new HashSet<>(Arrays.asList(firstFactors))) { + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + }); + } + } + + public static void createRequiredSecondaryFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] requiredSecondaryFactors) + throws SQLException, StorageQueryException { + if (requiredSecondaryFactors == null || requiredSecondaryFactors.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantRequiredSecondaryFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);"; + for (String factorId : requiredSecondaryFactors) { + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + }); + } + } +} diff --git a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java index 7c82f71..af3931c 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java @@ -37,13 +37,17 @@ public class TenantConfigSQLHelper { public static class TenantConfigRowMapper implements RowMapper { ThirdPartyConfig.Provider[] providers; + String[] firstFactors; + String[] requiredSecondaryFactors; - private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers) { + private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] requiredSecondaryFactors) { this.providers = providers; + this.firstFactors = firstFactors; + this.requiredSecondaryFactors = requiredSecondaryFactors; } - public static TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers) { - return new TenantConfigRowMapper(providers); + public static TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] requiredSecondaryFactors) { + return new TenantConfigRowMapper(providers, firstFactors, requiredSecondaryFactors); } @Override @@ -54,6 +58,8 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { new EmailPasswordConfig(result.getBoolean("email_password_enabled")), new ThirdPartyConfig(result.getBoolean("third_party_enabled"), this.providers), new PasswordlessConfig(result.getBoolean("passwordless_enabled")), + firstFactors.length == 0 ? null : firstFactors, + requiredSecondaryFactors.length == 0 ? null : requiredSecondaryFactors, JsonUtils.stringToJsonObject(result.getString("core_config")) ); } catch (Exception e) { @@ -62,9 +68,10 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { } } - public static TenantConfig[] selectAll(Start start, HashMap> providerMap) + public static TenantConfig[] selectAll(Start start, HashMap> providerMap, HashMap firstFactorsMap, HashMap requiredSecondaryFactorsMap) throws SQLException, StorageQueryException { - String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config, email_password_enabled, passwordless_enabled, third_party_enabled FROM " + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config," + + " email_password_enabled, passwordless_enabled, third_party_enabled FROM " + getConfig(start).getTenantConfigsTable() + ";"; TenantConfig[] tenantConfigs = execute(start, QUERY, pst -> {}, result -> { @@ -75,7 +82,11 @@ public static TenantConfig[] selectAll(Start start, HashMap { diff --git a/src/test/java/io/supertokens/storage/mysql/test/DbConnectionPoolTest.java b/src/test/java/io/supertokens/storage/mysql/test/DbConnectionPoolTest.java index dfc72fb..ccfeb27 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/DbConnectionPoolTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/DbConnectionPoolTest.java @@ -79,7 +79,7 @@ public void testActiveConnectionsWithTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(1000); // let the new tenant be ready @@ -94,7 +94,7 @@ public void testActiveConnectionsWithTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(2000); // let the new tenant be ready @@ -137,7 +137,7 @@ public void testDownTimeWhenChangingConnectionPoolSize() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(5000); // let the new tenant be ready @@ -197,7 +197,7 @@ public void testDownTimeWhenChangingConnectionPoolSize() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(3000); // let the new tenant be ready @@ -280,7 +280,7 @@ public void testMinimumIdleConnectionForTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(1000); // let the new tenant be ready @@ -296,7 +296,7 @@ public void testMinimumIdleConnectionForTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(2000); // let the new tenant be ready @@ -339,7 +339,7 @@ public void testIdleConnectionTimeout() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(3000); // let the new tenant be ready diff --git a/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java b/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java index 9044c1a..db3633d 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java @@ -33,7 +33,7 @@ import io.supertokens.pluginInterface.sqlStorage.SQLStorage; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; -import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownTotpUserIdException; import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.pluginInterface.sqlStorage.SQLStorage.TransactionIsolationLevel; @@ -280,7 +280,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), device); long now = System.currentTimeMillis(); @@ -290,7 +290,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { try { totpStorage.insertUsedCode_Transaction(con, TenantIdentifier.BASE_TENANT, code); totpStorage.commitTransaction(con); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { // This should not happen throw new StorageTransactionLogicException(e); } @@ -392,7 +392,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { TOTPUsedCode code2 = new TOTPUsedCode("user", "1234", false, nextDay, now + 1); try { totpStorage.insertUsedCode_Transaction(con, TenantIdentifier.BASE_TENANT, code2); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { // This should not happen throw new StorageTransactionLogicException(e); } @@ -405,7 +405,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { } catch (StorageTransactionLogicException e) { Exception e2 = e.actualException; - if (e2 instanceof TotpNotEnabledException) { + if (e2 instanceof UnknownTotpUserIdException) { t2Failed.set(true); } } catch (StorageQueryException e) { @@ -445,7 +445,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), device); long now = System.currentTimeMillis(); @@ -455,7 +455,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { try { totpStorage.insertUsedCode_Transaction(con, TenantIdentifier.BASE_TENANT, code); totpStorage.commitTransaction(con); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { // This should not happen throw new StorageTransactionLogicException(e); } @@ -573,7 +573,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { } catch (StorageTransactionLogicException e) { Exception e2 = e.actualException; - if (e2 instanceof TotpNotEnabledException) { + if (e2 instanceof UnknownTotpUserIdException) { t2Failed.set(true); } } catch (StorageQueryException e) { diff --git a/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java b/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java index 2a54c8e..96dbf2a 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java @@ -286,6 +286,7 @@ public void confirmHikariLoggerClosedOnlyWhenProcessEnds() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, config ), false); @@ -520,8 +521,7 @@ public void testDBPasswordIsNotLoggedWhenTenantIsCreated() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config - )); + null, null, config)); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -565,7 +565,7 @@ public void testDBPasswordIsNotLoggedWhenTenantIsCreated() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - new JsonObject())); + null, null, new JsonObject())); } catch (Exception e) { diff --git a/src/test/java/io/supertokens/storage/mysql/test/OneMillionUsersTest.java b/src/test/java/io/supertokens/storage/mysql/test/OneMillionUsersTest.java new file mode 100644 index 0000000..3714a37 --- /dev/null +++ b/src/test/java/io/supertokens/storage/mysql/test/OneMillionUsersTest.java @@ -0,0 +1,898 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.mysql.test; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.authRecipe.AuthRecipe; +import io.supertokens.authRecipe.UserPaginationContainer; +import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.emailpassword.ParsedFirebaseSCryptResponse; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.passwordless.Passwordless; +import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; +import io.supertokens.pluginInterface.authRecipe.LoginMethod; +import io.supertokens.pluginInterface.authRecipe.sqlStorage.AuthRecipeSQLStorage; +import io.supertokens.pluginInterface.emailpassword.sqlStorage.EmailPasswordSQLStorage; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; +import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.session.Session; +import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storage.mysql.test.httpRequest.HttpRequestForTesting; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.thirdparty.ThirdParty; +import io.supertokens.useridmapping.UserIdMapping; +import io.supertokens.usermetadata.UserMetadata; +import io.supertokens.userroles.UserRoles; +import io.supertokens.utils.SemVer; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; + +public class OneMillionUsersTest { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + static int TOTAL_USERS = 1000000; + static int NUM_THREADS = 16; + + Object lock = new Object(); + Set allUserIds = new HashSet<>(); + Set allPrimaryUserIds = new HashSet<>(); + Map userIdMappings = new HashMap<>(); + Map primaryUserIdMappings = new HashMap<>(); + + private void createEmailPasswordUsers(Main main) throws Exception { + System.out.println("Creating emailpassword users..."); + + int firebaseMemCost = 14; + int firebaseRounds = 8; + String firebaseSaltSeparator = "Bw=="; + + String salt = "/cj0jC1br5o4+w=="; + String passwordHash = "qZM035es5AXYqavsKD6/rhtxg7t5PhcyRgv5blc3doYbChX8keMfQLq1ra96O2Pf2TP/eZrR5xtPCYN6mX3ESA" + + "=="; + String combinedPasswordHash = "$" + ParsedFirebaseSCryptResponse.FIREBASE_SCRYPT_PREFIX + "$" + passwordHash + + "$" + salt + "$m=" + firebaseMemCost + "$r=" + firebaseRounds + "$s=" + firebaseSaltSeparator; + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + EmailPasswordSQLStorage storage = (EmailPasswordSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + try { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + + storage.signUp(TenantIdentifier.BASE_TENANT, userId, "eptest" + finalI + "@example.com", combinedPasswordHash, + timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createPasswordlessUsersWithEmail(Main main) throws Exception { + System.out.println("Creating passwordless (email) users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + PasswordlessSQLStorage storage = (PasswordlessSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + try { + storage.createUser(TenantIdentifier.BASE_TENANT, userId, "pltest" + finalI + "@example.com", null, timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createPasswordlessUsersWithPhone(Main main) throws Exception { + System.out.println("Creating passwordless (phone) users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + PasswordlessSQLStorage storage = (PasswordlessSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + try { + storage.createUser(TenantIdentifier.BASE_TENANT, userId, null, "+91987654" + finalI, timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createThirdpartyUsers(Main main) throws Exception { + System.out.println("Creating thirdparty users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + ThirdPartySQLStorage storage = (ThirdPartySQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + + try { + storage.signUp(TenantIdentifier.BASE_TENANT, userId, "tptest" + finalI + "@example.com", new LoginMethod.ThirdParty("google", "googleid" + finalI), timeJoined ); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + (finalI +1) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createOneMillionUsers(Main main) throws Exception { + Thread.sleep(5000); + + createEmailPasswordUsers(main); + createPasswordlessUsersWithEmail(main); + createPasswordlessUsersWithPhone(main); + createThirdpartyUsers(main); + } + + private void createUserIdMappings(Main main) throws Exception { + System.out.println("Creating user id mappings..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + AtomicLong usersUpdated = new AtomicLong(0); + + for (String userId : allUserIds) { + es.execute(() -> { + String extUserId = "ext" + UUID.randomUUID().toString(); + try { + UserIdMapping.createUserIdMapping(main, userId, extUserId, null, false); + synchronized (lock) { + userIdMappings.put(userId, extUserId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + long count = usersUpdated.incrementAndGet(); + if (count % 10000 == 9999) { + System.out.println("Updated " + (count) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createUserData(Main main) throws Exception { + System.out.println("Creating user data..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS / 2); + + for (String userId : allPrimaryUserIds) { + es.execute(() -> { + Random random = new Random(); + + // User Metadata + JsonObject metadata = new JsonObject(); + metadata.addProperty("random", random.nextDouble()); + + try { + UserMetadata.updateUserMetadata(main, userIdMappings.get(userId), metadata); + + // User Roles + if (random.nextBoolean()) { + UserRoles.addRoleToUser(main, userIdMappings.get(userId), "admin"); + } else { + UserRoles.addRoleToUser(main, userIdMappings.get(userId), "user"); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + es.awaitTermination(1, TimeUnit.MINUTES); + } + + private void doAccountLinking(Main main) throws Exception { + Set userIds = new HashSet<>(); + userIds.addAll(allUserIds); + + assertEquals(TOTAL_USERS, userIds.size()); + + AtomicLong accountsLinked = new AtomicLong(0); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + while (userIds.size() > 0) { + int numberOfItemsToPick = Math.min(new Random().nextInt(4) + 1, userIds.size()); + String[] userIdsArray = new String[numberOfItemsToPick]; + + Iterator iterator = userIds.iterator(); + for (int i = 0; i < numberOfItemsToPick; i++) { + userIdsArray[i] = iterator.next(); + iterator.remove(); + } + + AuthRecipeSQLStorage storage = (AuthRecipeSQLStorage) StorageLayer.getBaseStorage(main); + + es.execute(() -> { + try { + storage.startTransaction(con -> { + storage.makePrimaryUser_Transaction(new AppIdentifier(null, null), con, userIdsArray[0]); + storage.commitTransaction(con); + return null; + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + + try { + for (int i = 1; i < userIdsArray.length; i++) { + int finalI = i; + storage.startTransaction(con -> { + storage.linkAccounts_Transaction(new AppIdentifier(null, null), con, userIdsArray[finalI], + userIdsArray[0]); + storage.commitTransaction(con); + return null; + }); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + synchronized (lock) { + allPrimaryUserIds.add(userIdsArray[0]); + for (String userId : userIdsArray) { + primaryUserIdMappings.put(userId, userIdsArray[0]); + } + } + + long total = accountsLinked.addAndGet(userIdsArray.length); + if (total % 10000 > 9996) { + System.out.println("Linked " + (accountsLinked) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private static String accessToken; + private static String sessionUserId; + + private void createSessions(Main main) throws Exception { + System.out.println("Creating sessions..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + for (String userId : allUserIds) { + String finalUserId = userId; + es.execute(() -> { + try { + SessionInformationHolder session = Session.createNewSession(main, + userIdMappings.get(finalUserId), new JsonObject(), new JsonObject()); + + if (new Random().nextFloat() < 0.05) { + synchronized (lock) { + accessToken = session.accessToken.token; + sessionUserId = userIdMappings.get(primaryUserIdMappings.get(finalUserId)); + } + } + + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + @Test + public void testCreatingOneMillionUsers() throws Exception { + if (System.getenv("ONE_MILLION_USERS_TEST") == null) { + return; + } + + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); + Utils.setValueInConfig("firebase_password_hashing_signer_key", + "gRhC3eDeQOdyEn4bMd9c6kxguWVmcIVq/SKa0JDPFeM6TcEevkaW56sIWfx88OHbJKnCXdWscZx0l2WbCJ1wbg=="); + Utils.setValueInConfig("postgresql_connection_pool_size", "500"); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + AtomicBoolean memoryCheckRunning = new AtomicBoolean(true); + AtomicLong maxMemory = new AtomicLong(0); + + { + long st = System.currentTimeMillis(); + createOneMillionUsers(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create " + TOTAL_USERS + " users: " + ((en - st) / 1000) + " sec"); + assertEquals(TOTAL_USERS, AuthRecipe.getUsersCount(process.getProcess(), null)); + } + + { + long st = System.currentTimeMillis(); + doAccountLinking(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to link accounts: " + ((en - st) / 1000) + " sec"); + } + + { + long st = System.currentTimeMillis(); + createUserIdMappings(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create user id mappings: " + ((en - st) / 1000) + " sec"); + } + + { + UserRoles.createNewRoleOrModifyItsPermissions(process.getProcess(), "admin", new String[]{"p1"}); + UserRoles.createNewRoleOrModifyItsPermissions(process.getProcess(), "user", new String[]{"p2"}); + long st = System.currentTimeMillis(); + createUserData(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create user data: " + ((en - st) / 1000) + " sec"); + } + + { + long st = System.currentTimeMillis(); + createSessions(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create sessions: " + ((en - st) / 1000) + " sec"); + } + + sanityCheckAPIs(process.getProcess()); + allUserIds.clear(); + allPrimaryUserIds.clear(); + userIdMappings.clear(); + primaryUserIdMappings.clear(); + + process.kill(false); + + Runtime.getRuntime().gc(); + System.gc(); + System.runFinalization(); + Thread.sleep(10000); + + + process = TestingProcessManager.start(args, false); + Utils.setValueInConfig("firebase_password_hashing_signer_key", + "gRhC3eDeQOdyEn4bMd9c6kxguWVmcIVq/SKa0JDPFeM6TcEevkaW56sIWfx88OHbJKnCXdWscZx0l2WbCJ1wbg=="); + Utils.setValueInConfig("postgresql_connection_pool_size", "500"); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + Thread memoryChecker = new Thread(() -> { + while (memoryCheckRunning.get()) { + Runtime rt = Runtime.getRuntime(); + long total_mem = rt.totalMemory(); + long free_mem = rt.freeMemory(); + long used_mem = total_mem - free_mem; + + if (used_mem > maxMemory.get()) { + maxMemory.set(used_mem); + } + + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }); + memoryChecker.start(); + + measureOperations(process.getProcess()); + + memoryCheckRunning.set(false); + memoryChecker.join(); + + System.out.println("Max memory used: " + (maxMemory.get() / (1024 * 1024)) + " MB"); + assert maxMemory.get() < 256 * 1024 * 1024; // must be less than 256 mb + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + private void sanityCheckAPIs(Main main) throws Exception { + { // Email password sign in + JsonObject responseBody = new JsonObject(); + responseBody.addProperty("email", "eptest10@example.com"); + responseBody.addProperty("password", "testPass123"); + + Thread.sleep(1); // add a small delay to ensure a unique timestamp + long beforeSignIn = System.currentTimeMillis(); + + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, SemVer.v4_0.get(), + "emailpassword"); + + assertEquals(signInResponse.get("status").getAsString(), "OK"); + assertEquals(signInResponse.entrySet().size(), 3); + + JsonObject jsonUser = signInResponse.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("eptest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, beforeSignIn); + assert (activeUsers == 1); + } + + { // passwordless sign in + long startTs = System.currentTimeMillis(); + + String email = "pltest10@example.com"; + Passwordless.CreateCodeResponse createResp = Passwordless.createCode(main, email, null, null, null); + + JsonObject consumeCodeRequestBody = new JsonObject(); + consumeCodeRequestBody.addProperty("deviceId", createResp.deviceId); + consumeCodeRequestBody.addProperty("preAuthSessionId", createResp.deviceIdHash); + consumeCodeRequestBody.addProperty("userInputCode", createResp.userInputCode); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signinup/code/consume", consumeCodeRequestBody, 1000, 1000, null, + SemVer.v5_0.get(), "passwordless"); + + assertEquals("OK", response.get("status").getAsString()); + assertEquals(false, response.get("createdNewUser").getAsBoolean()); + assert (response.has("user")); + + JsonObject jsonUser = response.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("pltest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, startTs); + assert (activeUsers == 1); + } + + { // thirdparty sign in + long startTs = System.currentTimeMillis(); + JsonObject emailObject = new JsonObject(); + emailObject.addProperty("id", "tptest10@example.com"); + emailObject.addProperty("isVerified", true); + + JsonObject signUpRequestBody = new JsonObject(); + signUpRequestBody.addProperty("thirdPartyId", "google"); + signUpRequestBody.addProperty("thirdPartyUserId", "googleid10"); + signUpRequestBody.add("email", emailObject); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signinup", signUpRequestBody, 1000, 1000, null, + SemVer.v4_0.get(), "thirdparty"); + + assertEquals("OK", response.get("status").getAsString()); + assertEquals(false, response.get("createdNewUser").getAsBoolean()); + assert (response.has("user")); + + JsonObject jsonUser = response.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("tptest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, startTs); + assert (activeUsers == 1); + } + + { // session for user + JsonObject request = new JsonObject(); + request.addProperty("accessToken", accessToken); + request.addProperty("doAntiCsrfCheck", false); + request.addProperty("enableAntiCsrf", false); + request.addProperty("checkDatabase", false); + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/session/verify", request, 1000, 1000, null, + SemVer.v5_0.get(), "session"); + assertEquals("OK", response.get("status").getAsString()); + assertEquals(sessionUserId, response.get("session").getAsJsonObject().get("userId").getAsString()); + } + + { // check user roles + JsonObject responseBody = new JsonObject(); + responseBody.addProperty("email", "eptest10@example.com"); + responseBody.addProperty("password", "testPass123"); + + Thread.sleep(1); // add a small delay to ensure a unique timestamp + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, SemVer.v4_0.get(), + "emailpassword"); + + HashMap QUERY_PARAMS = new HashMap<>(); + QUERY_PARAMS.put("userId", signInResponse.get("user").getAsJsonObject().get("id").getAsString()); + JsonObject response = HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/recipe/user/roles", QUERY_PARAMS, 1000, 1000, null, + SemVer.v2_14.get(), "userroles"); + + assertEquals(2, response.entrySet().size()); + assertEquals("OK", response.get("status").getAsString()); + + JsonArray userRolesArr = response.getAsJsonArray("roles"); + assertEquals(1, userRolesArr.size()); + assertTrue( + userRolesArr.get(0).getAsString().equals("admin") || userRolesArr.get(0).getAsString().equals("user") + ); + } + + { // check user metadata + HashMap QueryParams = new HashMap(); + QueryParams.put("userId", sessionUserId); + JsonObject resp = HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/recipe/user/metadata", QueryParams, 1000, 1000, null, + SemVer.v2_13.get(), "usermetadata"); + + assertEquals(2, resp.entrySet().size()); + assertEquals("OK", resp.get("status").getAsString()); + assert (resp.has("metadata")); + JsonObject respMetadata = resp.getAsJsonObject("metadata"); + assertEquals(1, respMetadata.entrySet().size()); + } + } + + private void measureOperations(Main main) throws Exception { + AtomicLong errorCount = new AtomicLong(0); + { // Emailpassword sign up + System.out.println("Measure email password sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + EmailPassword.signUp(main, "ep" + finalI + "@example.com", "password" + finalI); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("EP sign up " + time); + assert time < 15000; + } + { // Emailpassword sign in + System.out.println("Measure email password sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + EmailPassword.signIn(main, "ep" + finalI + "@example.com", "password" + finalI); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("EP sign in " + time); + assert time < 15000; + } + { // Passwordless sign-ups + System.out.println("Measure passwordless sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + Passwordless.CreateCodeResponse code = Passwordless.createCode(main, + "pl" + finalI + "@example.com", null, null, null); + Passwordless.consumeCode(main, code.deviceId, code.deviceIdHash, code.userInputCode, null); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("PL sign up " + time); + assert time < 5000; + } + { // Passwordless sign-ins + System.out.println("Measure passwordless sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + Passwordless.CreateCodeResponse code = Passwordless.createCode(main, + "pl" + finalI + "@example.com", null, null, null); + Passwordless.consumeCode(main, code.deviceId, code.deviceIdHash, code.userInputCode, null); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("PL sign in " + time); + assert time < 5000; + } + { // Thirdparty sign-ups + System.out.println("Measure thirdparty sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + ThirdParty.signInUp(main, "twitter", "twitterid" + finalI, "twitter" + finalI + "@example.com"); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Thirdparty sign up " + time); + assert time < 5000; + } + { // Thirdparty sign-ins + System.out.println("Measure thirdparty sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + ThirdParty.signInUp(main, "twitter", "twitterid" + finalI, "twitter" + finalI + "@example.com"); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Thirdparty sign in " + time); + assert time < 5000; + } + { // Measure user pagination + long time = measureTime(() -> { + try { + long count = 0; + UserPaginationContainer users = AuthRecipe.getUsers(main, 500, "ASC", null, null, null); + while (true) { + for (AuthRecipeUserInfo user : users.users) { + count += user.loginMethods.length; + } + if (users.nextPaginationToken == null) { + break; + } + users = AuthRecipe.getUsers(main, 500, "ASC", users.nextPaginationToken, null, null); + if (count >= 500) { + break; + } + } + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("User pagination " + time); + assert time < 2000; + } + { // Measure update user metadata + long time = measureTime(() -> { + try { + UserPaginationContainer users = AuthRecipe.getUsers(main, 1, "ASC", null, null, null); + UserIdMapping.populateExternalUserIdForUsers( + new AppIdentifier(null, null), + (StorageLayer.getBaseStorage(main)), + users.users); + + AuthRecipeUserInfo user = users.users[0]; + for (int i = 0; i < 500; i++) { + UserMetadata.updateUserMetadata(main, user.getSupertokensOrExternalUserId(), new JsonObject()); + } + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Update user metadata " + time); + } + + assertEquals(0, errorCount.get()); + } + + private static long measureTime(Supplier function) { + long startTime = System.nanoTime(); + + // Call the function + function.get(); + + long endTime = System.nanoTime(); + + // Calculate elapsed time in milliseconds + return (endTime - startTime) / 1000000; // Convert to milliseconds + } +} + diff --git a/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java b/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java index 2eae408..02d6a74 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java @@ -16,7 +16,7 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; -import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownTotpUserIdException; import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; @@ -54,13 +54,13 @@ public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedC storage.insertUsedCode_Transaction(con, TenantIdentifier.BASE_TENANT, usedCode); storage.commitTransaction(con); return null; - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException | TenantOrAppNotFoundException e) { throw new StorageTransactionLogicException(e); } }); } catch (StorageTransactionLogicException e) { Exception actual = e.actualException; - if (actual instanceof TotpNotEnabledException || actual instanceof UsedCodeAlreadyExistsException) { + if (actual instanceof UnknownTotpUserIdException || actual instanceof UsedCodeAlreadyExistsException) { throw actual; } else { throw e; @@ -82,7 +82,7 @@ public void totpCodeLengthTest() throws Exception { long now = System.currentTimeMillis(); long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now - TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); storage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), d1); // Try code with length > 8 diff --git a/src/test/java/io/supertokens/storage/mysql/test/SuperTokensSaaSSecretTest.java b/src/test/java/io/supertokens/storage/mysql/test/SuperTokensSaaSSecretTest.java index 5e052e9..b551c44 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/SuperTokensSaaSSecretTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/SuperTokensSaaSSecretTest.java @@ -105,7 +105,7 @@ public void testThatTenantCannotSetProtectedConfigIfSuperTokensSaaSSecretIsSet() Multitenancy.addNewOrUpdateAppOrTenant(process.main, new TenantConfig(new TenantIdentifier(null, null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), - j), true); + null, null, j), true); fail(); } catch (BadPermissionException e) { assertEquals(e.getMessage(), "Not allowed to modify DB related configs."); @@ -191,7 +191,7 @@ public void testThatTenantCannotGetProtectedConfigIfSuperTokensSaaSSecretIsSet() new TenantConfig(new TenantIdentifier(null, null, "t" + i), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), - j)); + null, null, j)); { JsonObject response = HttpRequestForTesting.sendJsonRequest(process.getProcess(), "", diff --git a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java index 6f583ca..a01b71b 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java @@ -112,6 +112,7 @@ public void mergingTenantWithBaseConfigWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -156,6 +157,7 @@ public void storageInstanceIsReusedAcrossTenants() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -203,14 +205,17 @@ public void storageInstanceIsReusedAcrossTenantsComplex() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig), new TenantConfig(new TenantIdentifier(null, "abc", "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig1), new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig1)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -275,6 +280,7 @@ public void mergingTenantWithBaseConfigWithInvalidConfigThrowsErrorWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -308,6 +314,7 @@ public void mergingTenantWithBaseConfigWithConflictingConfigsThrowsError() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -342,6 +349,7 @@ public void mergingDifferentConnectionPoolIdTenantWithBaseConfigWithConflictingC new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -377,6 +385,7 @@ public void mergingDifferentUserPoolIdTenantWithBaseConfigWithConflictingConfigs new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -421,6 +430,7 @@ public void newStorageIsNotCreatedWhenSameTenantIsAdded() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -468,6 +478,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[0] = new TenantConfig(new TenantIdentifier("c1", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -479,6 +490,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[1] = new TenantConfig(new TenantIdentifier("c1", null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -488,6 +500,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[2] = new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -497,6 +510,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[3] = new TenantConfig(new TenantIdentifier(null, null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -558,6 +572,7 @@ public void multipleTenantsSameUserPoolAndConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -593,6 +608,7 @@ public void multipleTenantsSameUserPoolAndDifferentConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -632,6 +648,7 @@ public void testCreating50StorageLayersUsage() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, config); try { Multitenancy.addNewOrUpdateAppOrTenant(process.getProcess(), new TenantIdentifier(null, null, null), @@ -685,6 +702,7 @@ public void testCantCreateTenantWithUnknownDb() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); try { @@ -723,6 +741,7 @@ public void testTenantCreationAndThenDbDownDbThrowsErrorInRecipesAndDoesntAffect new EmailPasswordConfig(true), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); StorageLayer.getMultitenancyStorage(process.getProcess()).createTenant(tenantConfig); @@ -794,6 +813,7 @@ public void testBadPortWithNewTenantShouldNotCauseItToWaitInput() throws Excepti new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); try { diff --git a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java index c13d875..77d1ad3 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java +++ b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java @@ -83,6 +83,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -100,6 +101,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -125,6 +127,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -142,6 +145,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false);