diff --git a/.circleci/config.yml b/.circleci/config.yml index 02c3a060..8fd9177a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,6 +77,88 @@ jobs: name: running tests command: (cd .circleci/ && ./doTests.sh) - slack/status + test-onemillionusers: + docker: + - image: rishabhpoddar/supertokens_postgresql_plugin_test + resource_class: large + steps: + - add_ssh_keys: + fingerprints: + - "14:68:18:82:73:00:e4:fc:9e:f3:6f:ce:1d:5c:6d:c4" + - checkout + - run: + name: update postgresql max_connections + command: | + sed -i 's/^#*\s*max_connections\s*=.*/max_connections = 10000/' /etc/postgresql/9.5/main/postgresql.conf + - run: + name: starting postgresql + command: | + (cd / && ./runPostgreSQL.sh) + - run: + name: create databases + command: | + psql -c "create database st0;" + psql -c "create database st1;" + psql -c "create database st2;" + psql -c "create database st3;" + psql -c "create database st4;" + psql -c "create database st5;" + psql -c "create database st6;" + psql -c "create database st7;" + psql -c "create database st8;" + psql -c "create database st9;" + psql -c "create database st10;" + psql -c "create database st11;" + psql -c "create database st12;" + psql -c "create database st13;" + psql -c "create database st14;" + psql -c "create database st15;" + psql -c "create database st16;" + psql -c "create database st17;" + psql -c "create database st18;" + psql -c "create database st19;" + psql -c "create database st20;" + psql -c "create database st21;" + psql -c "create database st22;" + psql -c "create database st23;" + psql -c "create database st24;" + psql -c "create database st25;" + psql -c "create database st26;" + psql -c "create database st27;" + psql -c "create database st28;" + psql -c "create database st29;" + psql -c "create database st30;" + psql -c "create database st31;" + psql -c "create database st32;" + psql -c "create database st33;" + psql -c "create database st34;" + psql -c "create database st35;" + psql -c "create database st36;" + psql -c "create database st37;" + psql -c "create database st38;" + psql -c "create database st39;" + psql -c "create database st40;" + psql -c "create database st41;" + psql -c "create database st42;" + psql -c "create database st43;" + psql -c "create database st44;" + psql -c "create database st45;" + psql -c "create database st46;" + psql -c "create database st47;" + psql -c "create database st48;" + psql -c "create database st49;" + psql -c "create database st50;" + - run: + name: running tests + command: (cd .circleci/ && ./doOneMillionUsersTests.sh) + - slack/status + mark-passed: + docker: + - image: rishabhpoddar/supertokens_postgresql_plugin_test + steps: + - checkout + - run: (cd .circleci && ./markPassed.sh) + - slack/status workflows: version: 2 @@ -89,4 +171,23 @@ workflows: tags: only: /dev-v[0-9]+(\.[0-9]+)*/ branches: - ignore: /.*/ \ No newline at end of file + ignore: /.*/ + - test-onemillionusers: + context: + - slack-notification + filters: + tags: + only: /dev-v[0-9]+(\.[0-9]+)*/ + branches: + ignore: /.*/ + - mark-passed: + context: + - slack-notification + filters: + tags: + only: /dev-v[0-9]+(\.[0-9]+)*/ + branches: + ignore: /.*/ + requires: + - test + - test-onemillionusers diff --git a/.circleci/doOneMillionUsersTests.sh b/.circleci/doOneMillionUsersTests.sh new file mode 100755 index 00000000..ec82508d --- /dev/null +++ b/.circleci/doOneMillionUsersTests.sh @@ -0,0 +1,135 @@ +function cleanup { + if test -f "pluginInterfaceExactVersionsOutput"; then + rm pluginInterfaceExactVersionsOutput + fi +} + +trap cleanup EXIT +cleanup + +pluginInterfaceJson=`cat ../pluginInterfaceSupported.json` +pluginInterfaceLength=`echo $pluginInterfaceJson | jq ".versions | length"` +pluginInterfaceArray=`echo $pluginInterfaceJson | jq ".versions"` +echo "got plugin interface relations" + +./getPluginInterfaceExactVersions.sh $pluginInterfaceLength "$pluginInterfaceArray" + +if [[ $? -ne 0 ]] +then + echo "all plugin interfaces found... failed. exiting!" + exit 1 +else + echo "all plugin interfaces found..." +fi + +# get plugin version +pluginVersion=`cat ../build.gradle | grep -e "version =" -e "version="` +while IFS='"' read -ra ADDR; do + counter=0 + for i in "${ADDR[@]}"; do + if [ $counter == 1 ] + then + pluginVersion=$i + fi + counter=$(($counter+1)) + done +done <<< "$pluginVersion" + +responseStatus=`curl -s -o /dev/null -w "%{http_code}" -X PUT \ + https://api.supertokens.io/0/plugin \ + -H 'Content-Type: application/json' \ + -H 'api-version: 0' \ + -d "{ + \"password\": \"$SUPERTOKENS_API_KEY\", + \"planType\":\"FREE\", + \"version\":\"$pluginVersion\", + \"pluginInterfaces\": $pluginInterfaceArray, + \"name\": \"postgresql\" +}"` +if [ $responseStatus -ne "200" ] +then + echo "failed plugin PUT API status code: $responseStatus. Exiting!" + exit 1 +fi + +someTestsRan=false +while read -u 10 line +do + if [[ $line = "" ]]; then + continue + fi + i=0 + currTag=`echo $line | jq .tag` + currTag=`echo $currTag | tr -d '"'` + + currVersion=`echo $line | jq .version` + currVersion=`echo $currVersion | tr -d '"'` + piX=$(cut -d'.' -f1 <<<"$currVersion") + piY=$(cut -d'.' -f2 <<<"$currVersion") + piVersion="$piX.$piY" + + someTestsRan=true + + response=`curl -s -X GET \ + "https://api.supertokens.io/0/plugin-interface/dependency/core/latest?password=$SUPERTOKENS_API_KEY&planType=FREE&mode=DEV&version=$piVersion" \ + -H 'api-version: 0'` + if [[ `echo $response | jq .core` == "null" ]] + then + echo "fetching latest X.Y version for core given plugin-interface X.Y version: $piVersion gave response: $response" + exit 1 + fi + coreVersionX2=$(echo $response | jq .core | tr -d '"') + + response=`curl -s -X GET \ + "https://api.supertokens.io/0/core/latest?password=$SUPERTOKENS_API_KEY&planType=FREE&mode=DEV&version=$coreVersionX2" \ + -H 'api-version: 0'` + if [[ `echo $response | jq .tag` == "null" ]] + then + echo "fetching latest X.Y.Z version for core X.Y version: $coreVersionX2 gave response: $response" + exit 1 + fi + coreVersionTag=$(echo $response | jq .tag | tr -d '"') + + cd ../../ + git clone git@github.com:supertokens/supertokens-root.git + cd supertokens-root + + update-alternatives --install "/usr/bin/java" "java" "/usr/java/jdk-15.0.1/bin/java" 2 + update-alternatives --install "/usr/bin/javac" "javac" "/usr/java/jdk-15.0.1/bin/javac" 2 + + pluginX=$(cut -d'.' -f1 <<<"$pluginVersion") + pluginY=$(cut -d'.' -f2 <<<"$pluginVersion") + echo -e "core,$coreVersionX2\nplugin-interface,$piVersion\npostgresql-plugin,$pluginX.$pluginY" > modules.txt + ./loadModules + cd supertokens-core + git checkout $coreVersionTag + cd ../supertokens-plugin-interface + git checkout $currTag + cd ../supertokens-postgresql-plugin + git checkout dev-v$pluginVersion + cd ../ + echo $SUPERTOKENS_API_KEY > apiPassword + export ONE_MILLION_USERS_TEST=1 + ./utils/setupTestEnv --cicd + ./gradlew :supertokens-postgresql-plugin:test --tests io.supertokens.storage.postgresql.test.OneMillionUsersTest + + if [[ $? -ne 0 ]] + then + cat logs/* + cd ../project/ + echo "test failed... exiting!" + exit 1 + fi + cd ../ + rm -rf supertokens-root + cd project/.circleci +done 10 { try { long now = System.currentTimeMillis(); + Connection sqlCon = (Connection) con.getConnection(); + TOTPQueries.createDevice_Transaction(this, sqlCon, tenantIdentifier.toAppIdentifier(), device); TOTPQueries.insertUsedCode_Transaction(this, - (Connection) con.getConnection(), tenantIdentifier, + sqlCon, tenantIdentifier, new TOTPUsedCode(userId, "123456", true, 1000 + now, now)); } catch (SQLException e) { throw new StorageTransactionLogicException(e); @@ -1328,25 +1327,6 @@ public int countUsersActiveSince(AppIdentifier appIdentifier, long time) throws } } - @Override - public int countUsersEnabledTotp(AppIdentifier appIdentifier) throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotp(this, appIdentifier); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public int countUsersEnabledTotpAndActiveSince(AppIdentifier appIdentifier, long time) - throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, appIdentifier, time); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - @Override public void deleteUserActive_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { @@ -2251,10 +2231,10 @@ public boolean updateOrDeleteExternalUserIdInfo(AppIdentifier appIdentifier, Str } @Override - public HashMap getUserIdMappingForSuperTokensIds(ArrayList userIds) + public HashMap getUserIdMappingForSuperTokensIds(AppIdentifier appIdentifier, ArrayList userIds) throws StorageQueryException { try { - return UserIdMappingQueries.getUserIdMappingWithUserIds(this, userIds); + return UserIdMappingQueries.getUserIdMappingWithUserIds(this, appIdentifier, userIds); } catch (SQLException e) { throw new StorageQueryException(e); } @@ -2628,26 +2608,60 @@ public void revokeExpiredSessions() throws StorageQueryException { } // TOTP recipe: + @TestOnly @Override public void createDevice(AppIdentifier appIdentifier, TOTPDevice device) - throws StorageQueryException, DeviceAlreadyExistsException, TenantOrAppNotFoundException { + throws DeviceAlreadyExistsException, TenantOrAppNotFoundException, StorageQueryException { try { - TOTPQueries.createDevice(this, appIdentifier, device); + startTransaction(con -> { + try { + createDevice_Transaction(con, new AppIdentifier(null, null), device); + } catch (DeviceAlreadyExistsException | TenantOrAppNotFoundException e) { + throw new StorageTransactionLogicException(e); + } + return null; + }); } catch (StorageTransactionLogicException e) { - Exception actualException = e.actualException; + if (e.actualException instanceof DeviceAlreadyExistsException) { + throw (DeviceAlreadyExistsException) e.actualException; + } else if (e.actualException instanceof TenantOrAppNotFoundException) { + throw (TenantOrAppNotFoundException) e.actualException; + } else if (e.actualException instanceof StorageQueryException) { + throw (StorageQueryException) e.actualException; + } + } + } + + @Override + public TOTPDevice createDevice_Transaction(TransactionConnection con, AppIdentifier appIdentifier, TOTPDevice device) + throws StorageQueryException, DeviceAlreadyExistsException, TenantOrAppNotFoundException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.createDevice_Transaction(this, sqlCon, appIdentifier, device); + return device; + } catch (SQLException e) { + Exception actualException = e; if (actualException instanceof PSQLException) { ServerErrorMessage errMsg = ((PSQLException) actualException).getServerErrorMessage(); if (isPrimaryKeyError(errMsg, Config.getConfig(this).getTotpUserDevicesTable())) { - throw new DeviceAlreadyExistsException(); + throw new DeviceAlreadyExistsException(); } else if (isForeignKeyConstraintError(errMsg, Config.getConfig(this).getTotpUsersTable(), "app_id")) { - throw new TenantOrAppNotFoundException(appIdentifier); + throw new TenantOrAppNotFoundException(appIdentifier); } - } + throw new StorageQueryException(e); + } + } - throw new StorageQueryException(e.actualException); + @Override + public TOTPDevice getDeviceByName_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId, String deviceName) throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return TOTPQueries.getDeviceByName_Transaction(this, sqlCon, appIdentifier, userId, deviceName); + } catch (SQLException e) { + throw new StorageQueryException(e); } } @@ -2701,8 +2715,8 @@ public boolean removeUser(TenantIdentifier tenantIdentifier, String userId) @Override public void updateDeviceName(AppIdentifier appIdentifier, String userId, String oldDeviceName, String newDeviceName) - throws StorageQueryException, DeviceAlreadyExistsException, - UnknownDeviceException { + throws StorageQueryException, + UnknownDeviceException, DeviceAlreadyExistsException { try { int updatedCount = TOTPQueries.updateDeviceName(this, appIdentifier, userId, oldDeviceName, newDeviceName); if (updatedCount == 0) { @@ -2712,7 +2726,7 @@ public void updateDeviceName(AppIdentifier appIdentifier, String userId, String if (e instanceof PSQLException) { ServerErrorMessage errMsg = ((PSQLException) e).getServerErrorMessage(); if (isPrimaryKeyError(errMsg, Config.getConfig(this).getTotpUserDevicesTable())) { - throw new DeviceAlreadyExistsException(); + throw new DeviceAlreadyExistsException(); } } throw new StorageQueryException(e); @@ -2743,7 +2757,7 @@ public TOTPDevice[] getDevices_Transaction(TransactionConnection con, AppIdentif @Override public void insertUsedCode_Transaction(TransactionConnection con, TenantIdentifier tenantIdentifier, TOTPUsedCode usedCodeObj) - throws StorageQueryException, TotpNotEnabledException, UsedCodeAlreadyExistsException, + throws StorageQueryException, UnknownTotpUserIdException, UsedCodeAlreadyExistsException, TenantOrAppNotFoundException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -2755,7 +2769,7 @@ public void insertUsedCode_Transaction(TransactionConnection con, TenantIdentifi throw new UsedCodeAlreadyExistsException(); } else if (isForeignKeyConstraintError(err, Config.getConfig(this).getTotpUsedCodesTable(), "user_id")) { - throw new TotpNotEnabledException(); + throw new UnknownTotpUserIdException(); } else if (isForeignKeyConstraintError(err, Config.getConfig(this).getTotpUsedCodesTable(), "tenant_id")) { throw new TenantOrAppNotFoundException(tenantIdentifier); } @@ -3001,6 +3015,25 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A } } + @Override + public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException { + try { + return GeneralQueries.getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long sinceTime) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(this, + appIdentifier, sinceTime); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @TestOnly public int getDbActivityCount(String dbname) throws SQLException, StorageQueryException { String QUERY = "SELECT COUNT(*) as c FROM pg_stat_activity WHERE datname = ?;"; diff --git a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java index 2c464849..e0a0c682 100644 --- a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java +++ b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java @@ -190,6 +190,14 @@ public String getTenantConfigsTable() { return addSchemaAndPrefixToTableName("tenant_configs"); } + public String getTenantFirstFactorsTable() { + return addSchemaAndPrefixToTableName("tenant_first_factors"); + } + + public String getTenantRequiredSecondaryFactorsTable() { + return addSchemaAndPrefixToTableName("tenant_required_secondary_factors"); + } + public String getTenantThirdPartyProvidersTable() { return addSchemaAndPrefixToTableName("tenant_thirdparty_providers"); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java index d40c08f1..3a39c384 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java @@ -11,6 +11,7 @@ import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; +import static io.supertokens.storage.postgresql.config.Config.getConfig; public class ActiveUsersQueries { static String getQueryToCreateUserLastActiveTable(Start start) { @@ -51,9 +52,10 @@ public static int countUsersActiveSince(Start start, AppIdentifier appIdentifier public static int countUsersActiveSinceAndHasMoreThanOneLoginMethod(Start start, AppIdentifier appIdentifier, long sinceTime) throws SQLException, StorageQueryException { + // TODO: Active users are present only on public tenant and MFA users may be present on different storages String QUERY = "SELECT count(1) as c FROM (" + " SELECT count(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" - + " FROM " + Config.getConfig(start).getUsersTable() + + " FROM " + Config.getConfig(start).getAppIdToUserIdTable() + " WHERE primary_or_recipe_user_id IN (" + " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable() + " WHERE app_id = ? AND last_active_time >= ?" @@ -71,40 +73,6 @@ public static int countUsersActiveSinceAndHasMoreThanOneLoginMethod(Start start, }); } - public static int countUsersEnabledTotp(Start start, AppIdentifier appIdentifier) - throws SQLException, StorageQueryException { - String QUERY = "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable() - + " WHERE app_id = ?"; - - return execute(start, QUERY, pst -> { - pst.setString(1, appIdentifier.getAppId()); - }, result -> { - if (result.next()) { - return result.getInt("total"); - } - return 0; - }); - } - - public static int countUsersEnabledTotpAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime) - throws SQLException, StorageQueryException { - String QUERY = - "SELECT COUNT(*) as total FROM " + Config.getConfig(start).getTotpUsersTable() + " AS totp_users " - + "INNER JOIN " + Config.getConfig(start).getUserLastActiveTable() + " AS user_last_active " - + "ON totp_users.user_id = user_last_active.user_id " - + "WHERE user_last_active.app_id = ? AND user_last_active.last_active_time >= ?"; - - return execute(start, QUERY, pst -> { - pst.setString(1, appIdentifier.getAppId()); - pst.setLong(2, sinceTime); - }, result -> { - if (result.next()) { - return result.getInt("total"); - } - return 0; - }); - } - public static int updateUserLastActive(Start start, AppIdentifier appIdentifier, String userId) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getUserLastActiveTable() @@ -152,4 +120,41 @@ public static void deleteUserActive_Transaction(Connection con, Start start, App pst.setString(2, userId); }); } + + public static int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime) + throws SQLException, StorageQueryException { + // TODO: Active users are present only on public tenant and MFA users may be present on different storages + String QUERY = + "SELECT COUNT (DISTINCT user_id) as c FROM (" + + " (" // users with more than one login method + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? AND primary_or_recipe_user_id IN (" + + " SELECT user_id FROM " + getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " GROUP BY (app_id, primary_or_recipe_user_id)" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " ) UNION (" // TOTP users + + " SELECT user_id FROM " + getConfig(start).getTotpUsersTable() + + " WHERE app_id = ? AND user_id IN (" + + " SELECT user_id FROM " + getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " )" + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + pst.setLong(3, sinceTime); + pst.setString(4, appIdentifier.getAppId()); + pst.setString(5, appIdentifier.getAppId()); + pst.setLong(6, sinceTime); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java index ff9fc950..6fd00660 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java @@ -271,7 +271,7 @@ public static List isEmailVerified_transaction(Start start, Connection s // calculating the verified emails HashMap supertokensUserIdToExternalUserIdMap = UserIdMappingQueries.getUserIdMappingWithUserIds_Transaction(start, - sqlCon, supertokensUserIds); + sqlCon, appIdentifier, supertokensUserIds); HashMap externalUserIdToSupertokensUserIdMap = new HashMap<>(); List supertokensOrExternalUserIdsToQuery = new ArrayList<>(); @@ -340,7 +340,7 @@ public static List isEmailVerified(Start start, AppIdentifier appIdentif // We have external user id stored in the email verification table, so we need to fetch the mapped userids for // calculating the verified emails HashMap supertokensUserIdToExternalUserIdMap = UserIdMappingQueries.getUserIdMappingWithUserIds(start, - supertokensUserIds); + appIdentifier, supertokensUserIds); HashMap externalUserIdToSupertokensUserIdMap = new HashMap<>(); List supertokensOrExternalUserIdsToQuery = new ArrayList<>(); for (String userId : supertokensUserIds) { diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java index 729b8a51..8bc2d561 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java @@ -327,6 +327,25 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getTenantFirstFactorsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER); + + // index + update(start, MultitenancyQueries.getQueryToCreateTenantIdIndexForFirstFactorsTable(start), + NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTenantRequiredSecondaryFactorsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateRequiredSecondaryFactorsTable(start), NO_OP_SETTER); + + // index + update(start, + MultitenancyQueries.getQueryToCreateTenantIdIndexForRequiredSecondaryFactorsTable(start), + NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProviderClientsTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProviderClientsTable(start), @@ -568,6 +587,8 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer + getConfig(start).getUserIdMappingTable() + "," + getConfig(start).getUsersTable() + "," + getConfig(start).getAccessTokenSigningKeysTable() + "," + + getConfig(start).getTenantFirstFactorsTable() + "," + + getConfig(start).getTenantRequiredSecondaryFactorsTable() + "," + getConfig(start).getTenantConfigsTable() + "," + getConfig(start).getTenantThirdPartyProvidersTable() + "," + getConfig(start).getTenantThirdPartyProviderClientsTable() + "," @@ -590,7 +611,8 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer + getConfig(start).getUserRolesTable() + "," + getConfig(start).getDashboardUsersTable() + "," + getConfig(start).getDashboardSessionsTable() + "," - + getConfig(start).getTotpUsedCodesTable() + "," + getConfig(start).getTotpUserDevicesTable() + "," + + getConfig(start).getTotpUsedCodesTable() + "," + + getConfig(start).getTotpUserDevicesTable() + "," + getConfig(start).getTotpUsersTable(); update(start, DROP_QUERY, NO_OP_SETTER); } @@ -1692,7 +1714,7 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti throws SQLException, StorageQueryException { String QUERY = "SELECT COUNT (1) as c FROM (" + " SELECT COUNT(user_id) as num_login_methods " - + " FROM " + getConfig(start).getUsersTable() + + " FROM " + getConfig(start).getAppIdToUserIdTable() + " WHERE app_id = ? " + " GROUP BY (app_id, primary_or_recipe_user_id) " + ") as nloginmethods WHERE num_login_methods > 1"; @@ -1704,6 +1726,32 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti }); } + public static int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(Start start, AppIdentifier appIdentifier) + throws SQLException, StorageQueryException { + String QUERY = + "SELECT COUNT (DISTINCT user_id) as c FROM (" + + " (" // Users with number of login methods > 1 + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? " + + " GROUP BY (app_id, primary_or_recipe_user_id)" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " ) UNION (" // TOTP users + + " SELECT user_id FROM " + getConfig(start).getTotpUsersTable() + + " WHERE app_id = ?" + + " )" + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } + public static boolean checkIfUsesAccountLinking(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT 1 FROM " diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/MultitenancyQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/MultitenancyQueries.java index 592cdbd0..e4801d45 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/MultitenancyQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/MultitenancyQueries.java @@ -16,13 +16,13 @@ package io.supertokens.storage.postgresql.queries; -import io.supertokens.pluginInterface.emailpassword.exceptions.UnknownUserIdException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; +import io.supertokens.storage.postgresql.queries.multitenancy.MfaSqlHelper; import io.supertokens.storage.postgresql.queries.multitenancy.TenantConfigSQLHelper; import io.supertokens.storage.postgresql.queries.multitenancy.ThirdPartyProviderClientSQLHelper; import io.supertokens.storage.postgresql.queries.multitenancy.ThirdPartyProviderSQLHelper; @@ -119,6 +119,53 @@ public static String getQueryToCreateThirdPartyIdIndexForTenantThirdPartyProvide + getConfig(start).getTenantThirdPartyProviderClientsTable() + " (connection_uri_domain, app_id, tenant_id, third_party_id);"; } + public static String getQueryToCreateFirstFactorsTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String tableName = Config.getConfig(start).getTenantFirstFactorsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey") + + " PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "tenant_id", "fkey") + + " FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateTenantIdIndexForFirstFactorsTable(Start start) { + return "CREATE INDEX IF NOT EXISTS tenant_first_factors_tenant_id_index ON " + + getConfig(start).getTenantFirstFactorsTable() + " (connection_uri_domain, app_id, tenant_id);"; + } + + public static String getQueryToCreateRequiredSecondaryFactorsTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String tableName = Config.getConfig(start).getTenantRequiredSecondaryFactorsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey") + + " PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "tenant_id", "fkey") + + " FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateTenantIdIndexForRequiredSecondaryFactorsTable(Start start) { + return "CREATE INDEX IF NOT EXISTS tenant_default_required_factor_ids_tenant_id_index ON " + + getConfig(start).getTenantRequiredSecondaryFactorsTable() + " (connection_uri_domain, app_id, tenant_id);"; + } + + private static void executeCreateTenantQueries(Start start, Connection sqlCon, TenantConfig tenantConfig) throws SQLException, StorageQueryException { @@ -131,6 +178,9 @@ private static void executeCreateTenantQueries(Start start, Connection sqlCon, T ThirdPartyProviderClientSQLHelper.create(start, sqlCon, tenantConfig, provider, providerClient); } } + + MfaSqlHelper.createFirstFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.firstFactors); + MfaSqlHelper.createRequiredSecondaryFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.requiredSecondaryFactors); } public static void createTenantConfig(Start start, TenantConfig tenantConfig) throws StorageQueryException, StorageTransactionLogicException { @@ -209,7 +259,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep // Map (tenantIdentifier) -> thirdPartyId -> provider HashMap> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap); - return TenantConfigSQLHelper.selectAll(start, providerMap); + // Map (tenantIdentifier) -> firstFactors + HashMap firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start); + + // Map (tenantIdentifier) -> requiredSecondaryFactors + HashMap requiredSecondaryFactorsMap = MfaSqlHelper.selectAllRequiredSecondaryFactors(start); + + return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, requiredSecondaryFactorsMap); } catch (SQLException throwables) { throw new StorageQueryException(throwables); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java index d6685638..0fe56e4d 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java @@ -166,18 +166,19 @@ public static SessionInfo getSessionInfo_Transaction(Start start, Connection con public static void updateSessionInfo_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, String sessionHandle, - String refreshTokenHash2, long expiry) + String refreshTokenHash2, long expiry, boolean useStaticKey) throws SQLException, StorageQueryException { String QUERY = "UPDATE " + getConfig(start).getSessionInfoTable() - + " SET refresh_token_hash_2 = ?, expires_at = ?" + + " SET refresh_token_hash_2 = ?, expires_at = ?, use_static_key = ?" + " WHERE app_id = ? AND tenant_id = ? AND session_handle = ?"; update(con, QUERY, pst -> { pst.setString(1, refreshTokenHash2); pst.setLong(2, expiry); - pst.setString(3, tenantIdentifier.getAppId()); - pst.setString(4, tenantIdentifier.getTenantId()); - pst.setString(5, sessionHandle); + pst.setBoolean(3, useStaticKey); + pst.setString(4, tenantIdentifier.getAppId()); + pst.setString(5, tenantIdentifier.getTenantId()); + pst.setString(6, sessionHandle); }); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java index dad5e52d..8a8269d5 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java @@ -54,6 +54,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," + + "created_at BIGINT," + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey") + " PRIMARY KEY (app_id, user_id, device_name)," + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "user_id", "fkey") @@ -121,7 +122,7 @@ private static int insertUser_Transaction(Start start, Connection con, AppIdenti private static int insertDevice_Transaction(Start start, Connection con, AppIdentifier appIdentifier, TOTPDevice device) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() - + " (app_id, user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?, ?)"; + + " (app_id, user_id, device_name, secret_key, period, skew, verified, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; return update(con, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -131,25 +132,31 @@ private static int insertDevice_Transaction(Start start, Connection con, AppIden pst.setInt(5, device.period); pst.setInt(6, device.skew); pst.setBoolean(7, device.verified); + pst.setLong(8, device.createdAt); }); } - public static void createDevice(Start start, AppIdentifier appIdentifier, TOTPDevice device) - throws StorageQueryException, StorageTransactionLogicException { - start.startTransaction(con -> { - Connection sqlCon = (Connection) con.getConnection(); - - try { - insertUser_Transaction(start, sqlCon, appIdentifier, device.userId); - insertDevice_Transaction(start, sqlCon, appIdentifier, device); - sqlCon.commit(); - } catch (SQLException e) { - throw new StorageTransactionLogicException(e); - } + public static void createDevice_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, TOTPDevice device) + throws SQLException, StorageQueryException { + insertUser_Transaction(start, sqlCon, appIdentifier, device.userId); + insertDevice_Transaction(start, sqlCon, appIdentifier, device); + } + + public static TOTPDevice getDeviceByName_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String userId, String deviceName) + throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE app_id = ? AND user_id = ? AND device_name = ? FOR UPDATE;"; + return execute(sqlCon, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, userId); + pst.setString(3, deviceName); + }, result -> { + if (result.next()) { + return TOTPDeviceRowMapper.getInstance().map(result); + } return null; }); - return; } public static int markDeviceAsVerified(Start start, AppIdentifier appIdentifier, String userId, String deviceName) @@ -321,7 +328,8 @@ public TOTPDevice map(ResultSet result) throws SQLException { result.getString("secret_key"), result.getInt("period"), result.getInt("skew"), - result.getBoolean("verified")); + result.getBoolean("verified"), + result.getLong("created_at")); } } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java index 24f4fab7..a2388765 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java @@ -128,7 +128,8 @@ public static UserIdMapping[] getUserIdMappingWithEitherSuperTokensUserIdOrExter } - public static HashMap getUserIdMappingWithUserIds(Start start, List userIds) + public static HashMap getUserIdMappingWithUserIds(Start start, + AppIdentifier appIdentifier, List userIds) throws SQLException, StorageQueryException { if (userIds.size() == 0) { @@ -137,7 +138,8 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L // No need to filter based on tenantId because the id list is already filtered for a tenant StringBuilder QUERY = new StringBuilder( - "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE supertokens_user_id IN ("); + "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE app_id = ? AND " + + "supertokens_user_id IN ("); for (int i = 0; i < userIds.size(); i++) { QUERY.append("?"); if (i != userIds.size() - 1) { @@ -147,9 +149,10 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L } QUERY.append(")"); return execute(start, QUERY.toString(), pst -> { + pst.setString(1, appIdentifier.getAppId()); for (int i = 0; i < userIds.size(); i++) { - // i+1 cause this starts with 1 and not 0 - pst.setString(i + 1, userIds.get(i)); + // i+2 cause this starts with 1 and not 0, and 1 is appId + pst.setString(i + 2, userIds.get(i)); } }, result -> { HashMap userIdMappings = new HashMap<>(); @@ -161,7 +164,9 @@ public static HashMap getUserIdMappingWithUserIds(Start start, L }); } - public static HashMap getUserIdMappingWithUserIds_Transaction(Start start, Connection sqlCon, List userIds) + public static HashMap getUserIdMappingWithUserIds_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, + List userIds) throws SQLException, StorageQueryException { if (userIds.size() == 0) { @@ -170,7 +175,8 @@ public static HashMap getUserIdMappingWithUserIds_Transaction(St // No need to filter based on tenantId because the id list is already filtered for a tenant StringBuilder QUERY = new StringBuilder( - "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE supertokens_user_id IN ("); + "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + " WHERE app_id = ? AND " + + "supertokens_user_id IN ("); for (int i = 0; i < userIds.size(); i++) { QUERY.append("?"); if (i != userIds.size() - 1) { @@ -180,9 +186,10 @@ public static HashMap getUserIdMappingWithUserIds_Transaction(St } QUERY.append(")"); return execute(sqlCon, QUERY.toString(), pst -> { + pst.setString(1, appIdentifier.getAppId()); for (int i = 0; i < userIds.size(); i++) { - // i+1 cause this starts with 1 and not 0 - pst.setString(i + 1, userIds.get(i)); + // i+2 cause this starts with 1 and not 0, and 1 is appId + pst.setString(i + 2, userIds.get(i)); } }, result -> { HashMap userIdMappings = new HashMap<>(); diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/MfaSqlHelper.java b/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/MfaSqlHelper.java new file mode 100644 index 00000000..b5abf91d --- /dev/null +++ b/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/MfaSqlHelper.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.postgresql.queries.multitenancy; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.storage.postgresql.Start; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.*; + +import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; +import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; +import static io.supertokens.storage.postgresql.config.Config.getConfig; + +public class MfaSqlHelper { + public static HashMap selectAllFirstFactors(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM " + + getConfig(start).getTenantFirstFactorsTable() + ";"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> firstFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), result.getString("app_id"), result.getString("tenant_id")); + if (!firstFactors.containsKey(tenantIdentifier)) { + firstFactors.put(tenantIdentifier, new ArrayList<>()); + } + + firstFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : firstFactors.keySet()) { + finalResult.put(tenantIdentifier, firstFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static HashMap selectAllRequiredSecondaryFactors(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM " + + getConfig(start).getTenantRequiredSecondaryFactorsTable() + ";"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> defaultRequiredFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), + result.getString("app_id"), result.getString("tenant_id")); + if (!defaultRequiredFactors.containsKey(tenantIdentifier)) { + defaultRequiredFactors.put(tenantIdentifier, new ArrayList<>()); + } + + defaultRequiredFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : defaultRequiredFactors.keySet()) { + finalResult.put(tenantIdentifier, defaultRequiredFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static void createFirstFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] firstFactors) + throws SQLException, StorageQueryException { + if (firstFactors == null || firstFactors.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantFirstFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);"; + for (String factorId : new HashSet<>(Arrays.asList(firstFactors))) { + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + }); + } + } + + public static void createRequiredSecondaryFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] requiredSecondaryFactors) + throws SQLException, StorageQueryException { + if (requiredSecondaryFactors == null || requiredSecondaryFactors.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantRequiredSecondaryFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);"; + for (String factorId : requiredSecondaryFactors) { + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + }); + } + } +} diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/TenantConfigSQLHelper.java b/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/TenantConfigSQLHelper.java index 0dfadb9b..1a2e00b2 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/TenantConfigSQLHelper.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/multitenancy/TenantConfigSQLHelper.java @@ -16,11 +16,15 @@ package io.supertokens.storage.postgresql.queries.multitenancy; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.queries.utils.JsonUtils; +import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; import java.sql.ResultSet; @@ -36,13 +40,17 @@ public class TenantConfigSQLHelper { public static class TenantConfigRowMapper implements RowMapper { ThirdPartyConfig.Provider[] providers; + String[] firstFactors; + String[] requiredSecondaryFactors; - private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers) { + private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] requiredSecondaryFactors) { this.providers = providers; + this.firstFactors = firstFactors; + this.requiredSecondaryFactors = requiredSecondaryFactors; } - public static TenantConfigSQLHelper.TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers) { - return new TenantConfigSQLHelper.TenantConfigRowMapper(providers); + public static TenantConfigSQLHelper.TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] requiredSecondaryFactors) { + return new TenantConfigSQLHelper.TenantConfigRowMapper(providers, firstFactors, requiredSecondaryFactors); } @Override @@ -53,6 +61,8 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { new EmailPasswordConfig(result.getBoolean("email_password_enabled")), new ThirdPartyConfig(result.getBoolean("third_party_enabled"), this.providers), new PasswordlessConfig(result.getBoolean("passwordless_enabled")), + firstFactors.length == 0 ? null : firstFactors, + requiredSecondaryFactors.length == 0 ? null : requiredSecondaryFactors, JsonUtils.stringToJsonObject(result.getString("core_config")) ); } catch (Exception e) { @@ -61,9 +71,10 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { } } - public static TenantConfig[] selectAll(Start start, HashMap> providerMap) + public static TenantConfig[] selectAll(Start start, HashMap> providerMap, HashMap firstFactorsMap, HashMap requiredSecondaryFactorsMap) throws SQLException, StorageQueryException { - String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config, email_password_enabled, passwordless_enabled, third_party_enabled FROM " + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config," + + " email_password_enabled, passwordless_enabled, third_party_enabled FROM " + getConfig(start).getTenantConfigsTable() + ";"; TenantConfig[] tenantConfigs = execute(start, QUERY, pst -> {}, result -> { @@ -74,7 +85,11 @@ public static TenantConfig[] selectAll(Start start, HashMap { pst.setString(1, tenantConfig.tenantIdentifier.getConnectionUriDomain()); diff --git a/src/main/java/io/supertokens/storage/postgresql/utils/Utils.java b/src/main/java/io/supertokens/storage/postgresql/utils/Utils.java index 7e662f8f..c4b78164 100644 --- a/src/main/java/io/supertokens/storage/postgresql/utils/Utils.java +++ b/src/main/java/io/supertokens/storage/postgresql/utils/Utils.java @@ -17,6 +17,8 @@ package io.supertokens.storage.postgresql.utils; +import com.google.gson.Gson; + import java.io.ByteArrayOutputStream; import java.io.PrintStream; import java.util.regex.Matcher; @@ -56,6 +58,13 @@ public static String generateCommaSeperatedQuestionMarks(int size) { return builder.toString(); } + public static String[] getStringArrayFromJsonString(String input) { + if (input == null) { + return null; + } + return new Gson().fromJson(input, String[].class); + } + public static String maskDBPassword(String log) { String regex = "(\\|db_pass\\|)(.*?)(\\|db_pass\\|)"; diff --git a/src/test/java/io/supertokens/storage/postgresql/test/AccountLinkingTests.java b/src/test/java/io/supertokens/storage/postgresql/test/AccountLinkingTests.java index 580cc049..5ab09431 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/AccountLinkingTests.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/AccountLinkingTests.java @@ -88,6 +88,7 @@ public void canLinkFailsIfTryingToLinkUsersAcrossDifferentStorageLayers() throws new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ) ); @@ -130,6 +131,7 @@ public void canLinkFailsIfTryingToLinkUsersAcrossDifferentStorageLayers() throws new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ) ); diff --git a/src/test/java/io/supertokens/storage/postgresql/test/DbConnectionPoolTest.java b/src/test/java/io/supertokens/storage/postgresql/test/DbConnectionPoolTest.java index c89a5ac4..470e6ce0 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/DbConnectionPoolTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/DbConnectionPoolTest.java @@ -81,7 +81,7 @@ public void testActiveConnectionsWithTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(1000); // let the new tenant be ready @@ -96,7 +96,7 @@ public void testActiveConnectionsWithTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(2000); // let the new tenant be ready @@ -139,7 +139,7 @@ public void testDownTimeWhenChangingConnectionPoolSize() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(5000); // let the new tenant be ready @@ -199,7 +199,7 @@ public void testDownTimeWhenChangingConnectionPoolSize() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(3000); // let the new tenant be ready @@ -281,7 +281,7 @@ public void testMinimumIdleConnectionForTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(1000); // let the new tenant be ready @@ -297,7 +297,7 @@ public void testMinimumIdleConnectionForTenants() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(2000); // let the new tenant be ready @@ -340,7 +340,7 @@ public void testIdleConnectionTimeout() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config + null, null, config ), false); Thread.sleep(3000); // let the new tenant be ready diff --git a/src/test/java/io/supertokens/storage/postgresql/test/DeadlockTest.java b/src/test/java/io/supertokens/storage/postgresql/test/DeadlockTest.java index 3dd0241f..22fe29bb 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/DeadlockTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/DeadlockTest.java @@ -35,7 +35,7 @@ import io.supertokens.pluginInterface.sqlStorage.SQLStorage.TransactionIsolationLevel; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; -import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownTotpUserIdException; import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; @@ -284,7 +284,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(new AppIdentifier(null, null), device); long now = System.currentTimeMillis(); @@ -294,7 +294,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { try { totpStorage.insertUsedCode_Transaction(con, new TenantIdentifier(null, null, null), code); totpStorage.commitTransaction(con); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException e) { // This should not happen throw new StorageTransactionLogicException(e); } catch (TenantOrAppNotFoundException e) { @@ -446,7 +446,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(new AppIdentifier(null, null), device); long now = System.currentTimeMillis(); @@ -456,7 +456,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { try { totpStorage.insertUsedCode_Transaction(con, new TenantIdentifier(null, null, null), code); totpStorage.commitTransaction(con); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException e) { // This should not happen throw new StorageTransactionLogicException(e); } catch (TenantOrAppNotFoundException e) { @@ -559,7 +559,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { TOTPUsedCode code2 = new TOTPUsedCode("user", "1234", false, nextDay, now + 1); try { totpStorage.insertUsedCode_Transaction(con, new TenantIdentifier(null, null, null), code2); - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException e) { // This should not happen throw new StorageTransactionLogicException(e); } catch (TenantOrAppNotFoundException e) { @@ -574,7 +574,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { } catch (StorageTransactionLogicException e) { Exception e2 = e.actualException; - if (e2 instanceof TotpNotEnabledException) { + if (e2 instanceof UnknownTotpUserIdException) { t2Failed.set(true); } } catch (StorageQueryException e) { diff --git a/src/test/java/io/supertokens/storage/postgresql/test/LoggingTest.java b/src/test/java/io/supertokens/storage/postgresql/test/LoggingTest.java index 1e027162..591a4ac0 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/LoggingTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/LoggingTest.java @@ -287,6 +287,7 @@ public void confirmHikariLoggerClosedOnlyWhenProcessEnds() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, config ), false); @@ -518,8 +519,8 @@ public void testDBPasswordIsNotLoggedWhenTenantIsCreated() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), - config - )); + null, null, config + )); process.kill(); assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); @@ -563,6 +564,7 @@ public void testDBPasswordIsNotLoggedWhenTenantIsCreated() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, new JsonObject())); } catch (Exception e) { diff --git a/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java b/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java new file mode 100644 index 00000000..408eb46d --- /dev/null +++ b/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java @@ -0,0 +1,905 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.postgresql.test; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import io.supertokens.ActiveUsers; +import io.supertokens.Main; +import io.supertokens.ProcessState; +import io.supertokens.authRecipe.AuthRecipe; +import io.supertokens.authRecipe.UserPaginationContainer; +import io.supertokens.emailpassword.EmailPassword; +import io.supertokens.emailpassword.ParsedFirebaseSCryptResponse; +import io.supertokens.featureflag.EE_FEATURES; +import io.supertokens.featureflag.FeatureFlagTestContent; +import io.supertokens.passwordless.Passwordless; +import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; +import io.supertokens.pluginInterface.authRecipe.LoginMethod; +import io.supertokens.pluginInterface.authRecipe.sqlStorage.AuthRecipeSQLStorage; +import io.supertokens.pluginInterface.emailpassword.sqlStorage.EmailPasswordSQLStorage; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; +import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; +import io.supertokens.session.Session; +import io.supertokens.session.info.SessionInformationHolder; +import io.supertokens.storage.postgresql.test.httpRequest.HttpRequestForTesting; +import io.supertokens.storageLayer.StorageLayer; +import io.supertokens.thirdparty.ThirdParty; +import io.supertokens.useridmapping.UserIdMapping; +import io.supertokens.usermetadata.UserMetadata; +import io.supertokens.userroles.UserRoles; +import io.supertokens.utils.SemVer; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import static org.junit.Assert.*; + +public class OneMillionUsersTest { + @Rule + public TestRule watchman = Utils.getOnFailure(); + + @AfterClass + public static void afterTesting() { + Utils.afterTesting(); + } + + @Before + public void beforeEach() { + Utils.reset(); + } + + static int TOTAL_USERS = 1000000; + static int NUM_THREADS = 16; + + Object lock = new Object(); + Set allUserIds = new HashSet<>(); + Set allPrimaryUserIds = new HashSet<>(); + Map userIdMappings = new HashMap<>(); + Map primaryUserIdMappings = new HashMap<>(); + + private void createEmailPasswordUsers(Main main) throws Exception { + System.out.println("Creating emailpassword users..."); + + int firebaseMemCost = 14; + int firebaseRounds = 8; + String firebaseSaltSeparator = "Bw=="; + + String salt = "/cj0jC1br5o4+w=="; + String passwordHash = "qZM035es5AXYqavsKD6/rhtxg7t5PhcyRgv5blc3doYbChX8keMfQLq1ra96O2Pf2TP/eZrR5xtPCYN6mX3ESA" + + "=="; + String combinedPasswordHash = "$" + ParsedFirebaseSCryptResponse.FIREBASE_SCRYPT_PREFIX + "$" + passwordHash + + "$" + salt + "$m=" + firebaseMemCost + "$r=" + firebaseRounds + "$s=" + firebaseSaltSeparator; + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + EmailPasswordSQLStorage storage = (EmailPasswordSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + try { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + + storage.signUp(TenantIdentifier.BASE_TENANT, userId, "eptest" + finalI + "@example.com", combinedPasswordHash, + timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createPasswordlessUsersWithEmail(Main main) throws Exception { + System.out.println("Creating passwordless (email) users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + PasswordlessSQLStorage storage = (PasswordlessSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + try { + storage.createUser(TenantIdentifier.BASE_TENANT, userId, "pltest" + finalI + "@example.com", null, timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createPasswordlessUsersWithPhone(Main main) throws Exception { + System.out.println("Creating passwordless (phone) users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + PasswordlessSQLStorage storage = (PasswordlessSQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + try { + storage.createUser(TenantIdentifier.BASE_TENANT, userId, null, "+91987654" + finalI, timeJoined); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + ((finalI +1)) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createThirdpartyUsers(Main main) throws Exception { + System.out.println("Creating thirdparty users..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + ThirdPartySQLStorage storage = (ThirdPartySQLStorage) StorageLayer.getBaseStorage(main); + + for (int i = 0; i < TOTAL_USERS / 4; i++) { + int finalI = i; + es.execute(() -> { + String userId = io.supertokens.utils.Utils.getUUID(); + long timeJoined = System.currentTimeMillis(); + + try { + storage.signUp(TenantIdentifier.BASE_TENANT, userId, "tptest" + finalI + "@example.com", new LoginMethod.ThirdParty("google", "googleid" + finalI), timeJoined ); + synchronized (lock) { + allUserIds.add(userId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (finalI % 10000 == 9999) { + System.out.println("Created " + (finalI +1) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createOneMillionUsers(Main main) throws Exception { + Thread.sleep(5000); + + createEmailPasswordUsers(main); + createPasswordlessUsersWithEmail(main); + createPasswordlessUsersWithPhone(main); + createThirdpartyUsers(main); + } + + private void createUserIdMappings(Main main) throws Exception { + System.out.println("Creating user id mappings..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + AtomicLong usersUpdated = new AtomicLong(0); + + for (String userId : allUserIds) { + es.execute(() -> { + String extUserId = "ext" + UUID.randomUUID().toString(); + try { + UserIdMapping.createUserIdMapping(main, userId, extUserId, null, false); + synchronized (lock) { + userIdMappings.put(userId, extUserId); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + long count = usersUpdated.incrementAndGet(); + if (count % 10000 == 9999) { + System.out.println("Updated " + (count) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private void createUserData(Main main) throws Exception { + System.out.println("Creating user data..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS / 2); + + for (String userId : allPrimaryUserIds) { + es.execute(() -> { + Random random = new Random(); + + // User Metadata + JsonObject metadata = new JsonObject(); + metadata.addProperty("random", random.nextDouble()); + + try { + UserMetadata.updateUserMetadata(main, userIdMappings.get(userId), metadata); + + // User Roles + if (random.nextBoolean()) { + UserRoles.addRoleToUser(main, userIdMappings.get(userId), "admin"); + } else { + UserRoles.addRoleToUser(main, userIdMappings.get(userId), "user"); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + es.awaitTermination(1, TimeUnit.MINUTES); + } + + private void doAccountLinking(Main main) throws Exception { + Set userIds = new HashSet<>(); + userIds.addAll(allUserIds); + + assertEquals(TOTAL_USERS, userIds.size()); + + AtomicLong accountsLinked = new AtomicLong(0); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + while (userIds.size() > 0) { + int numberOfItemsToPick = Math.min(new Random().nextInt(4) + 1, userIds.size()); + String[] userIdsArray = new String[numberOfItemsToPick]; + + Iterator iterator = userIds.iterator(); + for (int i = 0; i < numberOfItemsToPick; i++) { + userIdsArray[i] = iterator.next(); + iterator.remove(); + } + + AuthRecipeSQLStorage storage = (AuthRecipeSQLStorage) StorageLayer.getBaseStorage(main); + + es.execute(() -> { + try { + storage.startTransaction(con -> { + storage.makePrimaryUser_Transaction(new AppIdentifier(null, null), con, userIdsArray[0]); + storage.commitTransaction(con); + return null; + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + + try { + for (int i = 1; i < userIdsArray.length; i++) { + int finalI = i; + storage.startTransaction(con -> { + storage.linkAccounts_Transaction(new AppIdentifier(null, null), con, userIdsArray[finalI], + userIdsArray[0]); + storage.commitTransaction(con); + return null; + }); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + synchronized (lock) { + allPrimaryUserIds.add(userIdsArray[0]); + for (String userId : userIdsArray) { + primaryUserIdMappings.put(userId, userIdsArray[0]); + } + } + + long total = accountsLinked.addAndGet(userIdsArray.length); + if (total % 10000 > 9996) { + System.out.println("Linked " + (accountsLinked) + " users"); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + private static String accessToken; + private static String sessionUserId; + + private void createSessions(Main main) throws Exception { + System.out.println("Creating sessions..."); + + ExecutorService es = Executors.newFixedThreadPool(NUM_THREADS); + + for (String userId : allUserIds) { + String finalUserId = userId; + es.execute(() -> { + try { + SessionInformationHolder session = Session.createNewSession(main, + userIdMappings.get(finalUserId), new JsonObject(), new JsonObject()); + + if (new Random().nextFloat() < 0.05) { + synchronized (lock) { + accessToken = session.accessToken.token; + sessionUserId = userIdMappings.get(primaryUserIdMappings.get(finalUserId)); + } + } + + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + es.awaitTermination(10, TimeUnit.MINUTES); + } + + @Test + public void testCreatingOneMillionUsers() throws Exception { +// if (System.getenv("ONE_MILLION_USERS_TEST") == null) { +// return; +// } + + String[] args = {"../"}; + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); + Utils.setValueInConfig("firebase_password_hashing_signer_key", + "gRhC3eDeQOdyEn4bMd9c6kxguWVmcIVq/SKa0JDPFeM6TcEevkaW56sIWfx88OHbJKnCXdWscZx0l2WbCJ1wbg=="); + Utils.setValueInConfig("postgresql_connection_pool_size", "500"); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + AtomicBoolean memoryCheckRunning = new AtomicBoolean(true); + AtomicLong maxMemory = new AtomicLong(0); + + { + long st = System.currentTimeMillis(); + createOneMillionUsers(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create " + TOTAL_USERS + " users: " + ((en - st) / 1000) + " sec"); + assertEquals(TOTAL_USERS, AuthRecipe.getUsersCount(process.getProcess(), null)); + } + + { + long st = System.currentTimeMillis(); + doAccountLinking(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to link accounts: " + ((en - st) / 1000) + " sec"); + } + + { + long st = System.currentTimeMillis(); + createUserIdMappings(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create user id mappings: " + ((en - st) / 1000) + " sec"); + } + + { + UserRoles.createNewRoleOrModifyItsPermissions(process.getProcess(), "admin", new String[]{"p1"}); + UserRoles.createNewRoleOrModifyItsPermissions(process.getProcess(), "user", new String[]{"p2"}); + long st = System.currentTimeMillis(); + createUserData(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create user data: " + ((en - st) / 1000) + " sec"); + } + + { + long st = System.currentTimeMillis(); + createSessions(process.getProcess()); + long en = System.currentTimeMillis(); + System.out.println("Time taken to create sessions: " + ((en - st) / 1000) + " sec"); + } + + sanityCheckAPIs(process.getProcess()); + allUserIds.clear(); + allPrimaryUserIds.clear(); + userIdMappings.clear(); + primaryUserIdMappings.clear(); + + process.kill(false); + + Runtime.getRuntime().gc(); + System.gc(); + System.runFinalization(); + Thread.sleep(10000); + + process = TestingProcessManager.start(args, false); + Utils.setValueInConfig("firebase_password_hashing_signer_key", + "gRhC3eDeQOdyEn4bMd9c6kxguWVmcIVq/SKa0JDPFeM6TcEevkaW56sIWfx88OHbJKnCXdWscZx0l2WbCJ1wbg=="); + Utils.setValueInConfig("postgresql_connection_pool_size", "500"); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + process.kill(false); + process = TestingProcessManager.start(args, false); + Utils.setValueInConfig("firebase_password_hashing_signer_key", + "gRhC3eDeQOdyEn4bMd9c6kxguWVmcIVq/SKa0JDPFeM6TcEevkaW56sIWfx88OHbJKnCXdWscZx0l2WbCJ1wbg=="); + Utils.setValueInConfig("postgresql_connection_pool_size", "500"); + + FeatureFlagTestContent.getInstance(process.getProcess()) + .setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, new EE_FEATURES[]{ + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY}); + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + Thread memoryChecker = new Thread(() -> { + while (memoryCheckRunning.get()) { + Runtime rt = Runtime.getRuntime(); + long total_mem = rt.totalMemory(); + long free_mem = rt.freeMemory(); + long used_mem = total_mem - free_mem; + + if (used_mem > maxMemory.get()) { + maxMemory.set(used_mem); + } + + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }); + memoryChecker.start(); + + measureOperations(process.getProcess()); + + memoryCheckRunning.set(false); + memoryChecker.join(); + + System.out.println("Max memory used: " + (maxMemory.get() / (1024 * 1024)) + " MB"); + assert maxMemory.get() < 256 * 1024 * 1024; // must be less than 256 mb + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + private void sanityCheckAPIs(Main main) throws Exception { + { // Email password sign in + JsonObject responseBody = new JsonObject(); + responseBody.addProperty("email", "eptest10@example.com"); + responseBody.addProperty("password", "testPass123"); + + Thread.sleep(1); // add a small delay to ensure a unique timestamp + long beforeSignIn = System.currentTimeMillis(); + + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, SemVer.v4_0.get(), + "emailpassword"); + + assertEquals(signInResponse.get("status").getAsString(), "OK"); + assertEquals(signInResponse.entrySet().size(), 3); + + JsonObject jsonUser = signInResponse.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("eptest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, beforeSignIn); + assert (activeUsers == 1); + } + + { // passwordless sign in + long startTs = System.currentTimeMillis(); + + String email = "pltest10@example.com"; + Passwordless.CreateCodeResponse createResp = Passwordless.createCode(main, email, null, null, null); + + JsonObject consumeCodeRequestBody = new JsonObject(); + consumeCodeRequestBody.addProperty("deviceId", createResp.deviceId); + consumeCodeRequestBody.addProperty("preAuthSessionId", createResp.deviceIdHash); + consumeCodeRequestBody.addProperty("userInputCode", createResp.userInputCode); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signinup/code/consume", consumeCodeRequestBody, 1000, 1000, null, + SemVer.v5_0.get(), "passwordless"); + + assertEquals("OK", response.get("status").getAsString()); + assertEquals(false, response.get("createdNewUser").getAsBoolean()); + assert (response.has("user")); + + JsonObject jsonUser = response.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("pltest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, startTs); + assert (activeUsers == 1); + } + + { // thirdparty sign in + long startTs = System.currentTimeMillis(); + JsonObject emailObject = new JsonObject(); + emailObject.addProperty("id", "tptest10@example.com"); + emailObject.addProperty("isVerified", true); + + JsonObject signUpRequestBody = new JsonObject(); + signUpRequestBody.addProperty("thirdPartyId", "google"); + signUpRequestBody.addProperty("thirdPartyUserId", "googleid10"); + signUpRequestBody.add("email", emailObject); + + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signinup", signUpRequestBody, 1000, 1000, null, + SemVer.v4_0.get(), "thirdparty"); + + assertEquals("OK", response.get("status").getAsString()); + assertEquals(false, response.get("createdNewUser").getAsBoolean()); + assert (response.has("user")); + + JsonObject jsonUser = response.get("user").getAsJsonObject(); + JsonArray emails = jsonUser.get("emails").getAsJsonArray(); + boolean found = false; + + for (JsonElement elem : emails) { + if (elem.getAsString().equals("tptest10@example.com")) { + found = true; + break; + } + } + + assertTrue(found); + + int activeUsers = ActiveUsers.countUsersActiveSince(main, startTs); + assert (activeUsers == 1); + } + + { // session for user + JsonObject request = new JsonObject(); + request.addProperty("accessToken", accessToken); + request.addProperty("doAntiCsrfCheck", false); + request.addProperty("enableAntiCsrf", false); + request.addProperty("checkDatabase", false); + JsonObject response = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/session/verify", request, 1000, 1000, null, + SemVer.v5_0.get(), "session"); + assertEquals("OK", response.get("status").getAsString()); + assertEquals(sessionUserId, response.get("session").getAsJsonObject().get("userId").getAsString()); + } + + { // check user roles + JsonObject responseBody = new JsonObject(); + responseBody.addProperty("email", "eptest10@example.com"); + responseBody.addProperty("password", "testPass123"); + + Thread.sleep(1); // add a small delay to ensure a unique timestamp + JsonObject signInResponse = HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/recipe/signin", responseBody, 1000, 1000, null, SemVer.v4_0.get(), + "emailpassword"); + + HashMap QUERY_PARAMS = new HashMap<>(); + QUERY_PARAMS.put("userId", signInResponse.get("user").getAsJsonObject().get("id").getAsString()); + JsonObject response = HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/recipe/user/roles", QUERY_PARAMS, 1000, 1000, null, + SemVer.v2_14.get(), "userroles"); + + assertEquals(2, response.entrySet().size()); + assertEquals("OK", response.get("status").getAsString()); + + JsonArray userRolesArr = response.getAsJsonArray("roles"); + assertEquals(1, userRolesArr.size()); + assertTrue( + userRolesArr.get(0).getAsString().equals("admin") || userRolesArr.get(0).getAsString().equals("user") + ); + } + + { // check user metadata + HashMap QueryParams = new HashMap(); + QueryParams.put("userId", sessionUserId); + JsonObject resp = HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/recipe/user/metadata", QueryParams, 1000, 1000, null, + SemVer.v2_13.get(), "usermetadata"); + + assertEquals(2, resp.entrySet().size()); + assertEquals("OK", resp.get("status").getAsString()); + assert (resp.has("metadata")); + JsonObject respMetadata = resp.getAsJsonObject("metadata"); + assertEquals(1, respMetadata.entrySet().size()); + } + } + + private void measureOperations(Main main) throws Exception { + AtomicLong errorCount = new AtomicLong(0); + { // Emailpassword sign up + System.out.println("Measure email password sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + EmailPassword.signUp(main, "ep" + finalI + "@example.com", "password" + finalI); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("EP sign up " + time); + assert time < 15000; + } + { // Emailpassword sign in + System.out.println("Measure email password sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + EmailPassword.signIn(main, "ep" + finalI + "@example.com", "password" + finalI); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("EP sign in " + time); + assert time < 15000; + } + { // Passwordless sign-ups + System.out.println("Measure passwordless sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + Passwordless.CreateCodeResponse code = Passwordless.createCode(main, + "pl" + finalI + "@example.com", null, null, null); + Passwordless.consumeCode(main, code.deviceId, code.deviceIdHash, code.userInputCode, null); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("PL sign up " + time); + assert time < 5000; + } + { // Passwordless sign-ins + System.out.println("Measure passwordless sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + Passwordless.CreateCodeResponse code = Passwordless.createCode(main, + "pl" + finalI + "@example.com", null, null, null); + Passwordless.consumeCode(main, code.deviceId, code.deviceIdHash, code.userInputCode, null); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("PL sign in " + time); + assert time < 5000; + } + { // Thirdparty sign-ups + System.out.println("Measure thirdparty sign-ups"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + ThirdParty.signInUp(main, "twitter", "twitterid" + finalI, "twitter" + finalI + "@example.com"); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Thirdparty sign up " + time); + assert time < 5000; + } + { // Thirdparty sign-ins + System.out.println("Measure thirdparty sign-ins"); + long time = measureTime(() -> { + ExecutorService es = Executors.newFixedThreadPool(100); + for (int i = 0; i < 500; i++) { + int finalI = i; + es.execute(() -> { + try { + ThirdParty.signInUp(main, "twitter", "twitterid" + finalI, "twitter" + finalI + "@example.com"); + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + }); + } + es.shutdown(); + try { + es.awaitTermination(5, TimeUnit.MINUTES); + } catch (InterruptedException e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Thirdparty sign in " + time); + assert time < 5000; + } + { // Measure user pagination + long time = measureTime(() -> { + try { + long count = 0; + UserPaginationContainer users = AuthRecipe.getUsers(main, 500, "ASC", null, null, null); + while (true) { + for (AuthRecipeUserInfo user : users.users) { + count += user.loginMethods.length; + } + if (users.nextPaginationToken == null) { + break; + } + users = AuthRecipe.getUsers(main, 500, "ASC", users.nextPaginationToken, null, null); + if (count >= 500) { + break; + } + } + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("User pagination " + time); + assert time < 2000; + } + { // Measure update user metadata + long time = measureTime(() -> { + try { + UserPaginationContainer users = AuthRecipe.getUsers(main, 1, "ASC", null, null, null); + UserIdMapping.populateExternalUserIdForUsers( + new AppIdentifier(null, null), + (StorageLayer.getBaseStorage(main)), + users.users); + + AuthRecipeUserInfo user = users.users[0]; + for (int i = 0; i < 500; i++) { + UserMetadata.updateUserMetadata(main, user.getSupertokensOrExternalUserId(), new JsonObject()); + } + } catch (Exception e) { + errorCount.incrementAndGet(); + throw new RuntimeException(e); + } + return null; + }); + System.out.println("Update user metadata " + time); + } + + assertEquals(0, errorCount.get()); + } + + private static long measureTime(Supplier function) { + long startTime = System.nanoTime(); + + // Call the function + function.get(); + + long endTime = System.nanoTime(); + + // Calculate elapsed time in milliseconds + return (endTime - startTime) / 1000000; // Convert to milliseconds + } +} diff --git a/src/test/java/io/supertokens/storage/postgresql/test/StorageLayerTest.java b/src/test/java/io/supertokens/storage/postgresql/test/StorageLayerTest.java index b48324bb..8f6d1699 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/StorageLayerTest.java @@ -16,7 +16,7 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; -import io.supertokens.pluginInterface.totp.exception.TotpNotEnabledException; +import io.supertokens.pluginInterface.totp.exception.UnknownTotpUserIdException; import io.supertokens.pluginInterface.totp.exception.UsedCodeAlreadyExistsException; import io.supertokens.pluginInterface.totp.sqlStorage.TOTPSQLStorage; import io.supertokens.storageLayer.StorageLayer; @@ -54,7 +54,7 @@ public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedC storage.insertUsedCode_Transaction(con, new TenantIdentifier(null, null, null), usedCode); storage.commitTransaction(con); return null; - } catch (TotpNotEnabledException | UsedCodeAlreadyExistsException e) { + } catch (UnknownTotpUserIdException | UsedCodeAlreadyExistsException e) { throw new StorageTransactionLogicException(e); } catch (TenantOrAppNotFoundException e) { throw new IllegalStateException(e); @@ -62,7 +62,7 @@ public static void insertUsedCodeUtil(TOTPSQLStorage storage, TOTPUsedCode usedC }); } catch (StorageTransactionLogicException e) { Exception actual = e.actualException; - if (actual instanceof TotpNotEnabledException || actual instanceof UsedCodeAlreadyExistsException) { + if (actual instanceof UnknownTotpUserIdException || actual instanceof UsedCodeAlreadyExistsException) { throw actual; } else { throw e; @@ -85,7 +85,7 @@ public void totpCodeLengthTest() throws Exception { long now = System.currentTimeMillis(); long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now - TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); storage.createDevice(new AppIdentifier(null, null), d1); // Try code with length > 8 diff --git a/src/test/java/io/supertokens/storage/postgresql/test/SuperTokensSaaSSecretTest.java b/src/test/java/io/supertokens/storage/postgresql/test/SuperTokensSaaSSecretTest.java index 6693acb1..51673061 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/SuperTokensSaaSSecretTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/SuperTokensSaaSSecretTest.java @@ -89,6 +89,7 @@ public void testThatTenantCannotSetDatabaseRelatedConfigIfSuperTokensSaaSSecretI Multitenancy.addNewOrUpdateAppOrTenant(process.main, new TenantConfig(new TenantIdentifier(null, null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, j), true); fail(); } catch (BadPermissionException e) { @@ -165,6 +166,7 @@ public void testThatTenantCanSetDatabaseRelatedConfigIfSuperTokensSaaSSecretIsNo new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, j), false); } @@ -217,6 +219,7 @@ public void testThatTenantCannotGetDatabaseRelatedConfigIfSuperTokensSaaSSecretI new TenantConfig(new TenantIdentifier(null, null, "t" + i), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, j)); { diff --git a/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/StorageLayerTest.java b/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/StorageLayerTest.java index 7bca0a99..77204865 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/StorageLayerTest.java @@ -112,6 +112,7 @@ public void mergingTenantWithBaseConfigWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -162,6 +163,7 @@ public void storageInstanceIsReusedAcrossTenants() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -209,14 +211,17 @@ public void storageInstanceIsReusedAcrossTenantsComplex() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig), new TenantConfig(new TenantIdentifier(null, "abc", "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig1), new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig1)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -281,6 +286,7 @@ public void mergingTenantWithBaseConfigWithInvalidConfigThrowsErrorWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -314,6 +320,7 @@ public void mergingTenantWithBaseConfigWithConflictingConfigsThrowsError() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -348,6 +355,7 @@ public void mergingDifferentConnectionPoolIdTenantWithBaseConfigWithConflictingC new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -383,6 +391,7 @@ public void mergingDifferentUserPoolIdTenantWithBaseConfigWithConflictingConfigs new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -433,6 +442,7 @@ public void newStorageIsNotCreatedWhenSameTenantIsAdded() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -480,6 +490,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[0] = new TenantConfig(new TenantIdentifier("c1", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -491,6 +502,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[1] = new TenantConfig(new TenantIdentifier("c1", null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -500,6 +512,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[2] = new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -509,6 +522,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[3] = new TenantConfig(new TenantIdentifier(null, null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig); } @@ -572,6 +586,7 @@ public void differentUserPoolCreatedBasedOnSchemaInConnectionUri() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -614,6 +629,7 @@ public void multipleTenantsSameUserPoolAndConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -649,6 +665,7 @@ public void multipleTenantsSameUserPoolAndDifferentConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -688,6 +705,7 @@ public void testCreating50StorageLayersUsage() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, config); try { Multitenancy.addNewOrUpdateAppOrTenant(process.getProcess(), new TenantIdentifier(null, null, null), @@ -741,6 +759,7 @@ public void testCantCreateTenantWithUnknownDb() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); try { @@ -782,6 +801,7 @@ public void testTenantCreationAndThenDbDownDbThrowsErrorInRecipesAndDoesntAffect new EmailPasswordConfig(true), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); StorageLayer.getMultitenancyStorage(process.getProcess()).createTenant(tenantConfig); @@ -862,6 +882,7 @@ public void testBadPortWithNewTenantShouldNotCauseItToWaitInput() throws Excepti new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + null, null, tenantConfigJson); try { diff --git a/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/TestUserPoolIdChangeBehaviour.java b/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/TestUserPoolIdChangeBehaviour.java index bc5a791e..758e749a 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/TestUserPoolIdChangeBehaviour.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/multitenancy/TestUserPoolIdChangeBehaviour.java @@ -84,6 +84,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -101,6 +102,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -129,6 +131,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false); @@ -146,6 +149,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + null, null, coreConfig ), false);