From 0b11cd08c372d457b469c097de473f2c1ddc8f4a Mon Sep 17 00:00:00 2001 From: Tamas Soltesz Date: Thu, 19 Dec 2024 09:43:09 +0100 Subject: [PATCH] feat: multithreaded bulk import (#235) * feat: Add BulkImport APIs and cron * fix: PR changes * fix: PR changes * fix: PR changes * fix: PR changes * fix: PR changes * fix: PR changes * fix: Update version and changelog * fix: PR changes * fix: PR changes * fix: PR changes * fix: removing restriction of connection pool size for bulk import * fix: actually closing the connection * fix: add bulk import retry logic for postgres too * fix: fix failing tests * chore: current state save * fix: fixing merge error with changelog * feat: bulk inserting the bulk migration data * fix: fixes and error handling changes * fix: fixing tests * chore: changelog and build version update * fix: handling app/tenant not found * fix: review fix --------- Co-authored-by: Ankit Tiwari --- CHANGELOG.md | 27 + build.gradle | 2 +- .../postgresql/BulkImportProxyConnection.java | 329 ++++++++++ .../postgresql/BulkImportProxyStorage.java | 114 ++++ .../storage/postgresql/ConnectionPool.java | 13 +- .../postgresql/QueryExecutorTemplate.java | 14 +- .../supertokens/storage/postgresql/Start.java | 583 +++++++++++++++++- .../postgresql/config/PostgreSQLConfig.java | 4 + .../queries/ActiveUsersQueries.java | 30 +- .../postgresql/queries/BulkImportQueries.java | 347 +++++++++++ .../queries/EmailPasswordQueries.java | 132 +++- .../queries/EmailVerificationQueries.java | 128 ++++ .../postgresql/queries/GeneralQueries.java | 226 ++++++- .../queries/PasswordlessQueries.java | 174 ++++++ .../postgresql/queries/SessionQueries.java | 27 + .../postgresql/queries/TOTPQueries.java | 94 ++- .../postgresql/queries/ThirdPartyQueries.java | 196 +++++- .../queries/UserIdMappingQueries.java | 71 +++ .../queries/UserMetadataQueries.java | 59 ++ .../postgresql/queries/UserRolesQueries.java | 75 ++- .../postgresql/test/OneMillionUsersTest.java | 249 +++++++- 21 files changed, 2843 insertions(+), 51 deletions(-) create mode 100644 src/main/java/io/supertokens/storage/postgresql/BulkImportProxyConnection.java create mode 100644 src/main/java/io/supertokens/storage/postgresql/BulkImportProxyStorage.java create mode 100644 src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java diff --git a/CHANGELOG.md b/CHANGELOG.md index af622075..1d781000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,33 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [7.3.0] + +- Adds tables and queries for Bulk Import + +### Migration + +```sql +"CREATE TABLE IF NOT EXISTS bulk_import_users ( + id CHAR(36), + app_id VARCHAR(64) NOT NULL DEFAULT 'public', + primary_user_id VARCHAR(36), + raw_data TEXT NOT NULL, + status VARCHAR(128) DEFAULT 'NEW', + error_msg TEXT, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + CONSTRAINT bulk_import_users_pkey PRIMARY KEY(app_id, id), + CONSTRAINT bulk_import_users__app_id_fkey FOREIGN KEY(app_id) REFERENCES apps(app_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS bulk_import_users_status_updated_at_index ON bulk_import_users (app_id, status, updated_at); + +CREATE INDEX IF NOT EXISTS bulk_import_users_pagination_index1 ON bulk_import_users (app_id, status, created_at DESC, + id DESC); + +CREATE INDEX IF NOT EXISTS bulk_import_users_pagination_index2 ON bulk_import_users (app_id, created_at DESC, id DESC); +``` ## [7.2.0] - 2024-10-03 - Compatible with plugin interface version 6.3 diff --git a/build.gradle b/build.gradle index 553a2c88..112bd9a1 100644 --- a/build.gradle +++ b/build.gradle @@ -2,7 +2,7 @@ plugins { id 'java-library' } -version = "7.2.0" +version = "7.3.0" repositories { mavenCentral() diff --git a/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyConnection.java b/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyConnection.java new file mode 100644 index 00000000..29ab11aa --- /dev/null +++ b/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyConnection.java @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.postgresql; + +import java.sql.*; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Executor; + +/** +* BulkImportProxyConnection is a class implementing the Connection interface, serving as a Connection instance in the bulk import user cronjob. +* This cron extensively utilizes existing queries to import users, all of which internally operate within transactions and those query sometimes +* call the commit/rollback method on the connection. +* +* For the purpose of bulkimport cronjob, we aim to employ a single connection for all queries and rollback any operations in case of query failures. +* To achieve this, we use our own proxy Connection instance and override the commit/rollback/close methods to do nothing. +*/ + +public class BulkImportProxyConnection implements Connection { + private Connection con = null; + + public BulkImportProxyConnection(Connection con) { + this.con = con; + } + + @Override + public void close() throws SQLException { + //this.con.close(); + //we don't want to close here because we are trying to reuse existing code but also using the same connection + //for bulk importing + } + + @Override + public void commit() throws SQLException { + //this.con.commit(); + } + + @Override + public void rollback() throws SQLException { + //this.con.rollback(); + } + + public void closeForBulkImportProxyStorage() throws SQLException { + this.con.close(); + } + + public void commitForBulkImportProxyStorage() throws SQLException { + this.con.commit(); + } + + public void rollbackForBulkImportProxyStorage() throws SQLException { + this.con.rollback(); + } + + /* Following methods are unchaged */ + + @Override + public Statement createStatement() throws SQLException { + return this.con.createStatement(); + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + return this.con.prepareStatement(sql); + } + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + return this.con.prepareCall(sql); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + return this.con.nativeSQL(sql); + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + this.con.setAutoCommit(autoCommit); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return this.con.getAutoCommit(); + } + + @Override + public boolean isClosed() throws SQLException { + return this.con.isClosed(); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return this.con.getMetaData(); + } + + @Override + public void setReadOnly(boolean readOnly) throws SQLException { + this.con.setReadOnly(readOnly); + } + + @Override + public boolean isReadOnly() throws SQLException { + return this.con.isReadOnly(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + this.con.setCatalog(catalog); + } + + @Override + public String getCatalog() throws SQLException { + return this.con.getCatalog(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + this.con.setTransactionIsolation(level); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return this.con.getTransactionIsolation(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return this.con.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + this.con.clearWarnings(); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + return this.con.createStatement(resultSetType, resultSetConcurrency); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + return this.con.prepareStatement(sql, resultSetType, resultSetConcurrency); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return this.con.prepareCall(sql, resultSetType, resultSetConcurrency); + } + + @Override + public Map> getTypeMap() throws SQLException { + return this.con.getTypeMap(); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + this.con.setTypeMap(map); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + this.con.setHoldability(holdability); + } + + @Override + public int getHoldability() throws SQLException { + return this.con.getHoldability(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return this.con.setSavepoint(); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + return this.con.setSavepoint(name); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + this.con.rollback(savepoint); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + this.con.releaseSavepoint(savepoint); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return this.con.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return this.con.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return this.con.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + return this.con.prepareStatement(sql, autoGeneratedKeys); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + return this.con.prepareStatement(sql, columnIndexes); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + return this.con.prepareStatement(sql, columnNames); + } + + @Override + public Clob createClob() throws SQLException { + return this.con.createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return this.con.createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return this.con.createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return this.con.createSQLXML(); + } + + @Override + public boolean isValid(int timeout) throws SQLException { + return this.con.isValid(timeout); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + this.con.setClientInfo(name, value); + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + this.con.setClientInfo(properties); + } + + @Override + public String getClientInfo(String name) throws SQLException { + return this.con.getClientInfo(name); + } + + @Override + public Properties getClientInfo() throws SQLException { + return this.con.getClientInfo(); + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + return this.con.createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + return this.con.createStruct(typeName, attributes); + } + + @Override + public void setSchema(String schema) throws SQLException { + this.con.setSchema(schema); + } + + @Override + public String getSchema() throws SQLException { + return this.con.getSchema(); + } + + @Override + public void abort(Executor executor) throws SQLException { + this.con.abort(executor); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + this.con.setNetworkTimeout(executor, milliseconds); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return this.con.getNetworkTimeout(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return this.con.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return this.con.isWrapperFor(iface); + } +} \ No newline at end of file diff --git a/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyStorage.java b/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyStorage.java new file mode 100644 index 00000000..12eeff8a --- /dev/null +++ b/src/main/java/io/supertokens/storage/postgresql/BulkImportProxyStorage.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.postgresql; + +import io.supertokens.pluginInterface.exceptions.DbInitException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; +import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; + + +/** +* BulkImportProxyStorage is a class extending Start, serving as a Storage instance in the bulk import user cronjob. +* This cronjob extensively utilizes existing queries to import users, all of which internally operate within transactions. +* +* For the purpose of bulkimport cronjob, we aim to employ a single connection for all queries and rollback any operations in case of query failures. +* To achieve this, we override the startTransactionHelper method to utilize the same connection and prevent automatic query commits even upon transaction success. +* Subsequently, the cronjob is responsible for committing the transaction after ensuring the successful execution of all queries. +*/ + +public class BulkImportProxyStorage extends Start { + private BulkImportProxyConnection connection; + + public synchronized Connection getTransactionConnection() throws SQLException, StorageQueryException { + if (this.connection == null) { + Connection con = ConnectionPool.getConnectionForProxyStorage(this); + this.connection = new BulkImportProxyConnection(con); + connection.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ); + connection.setAutoCommit(false); + } + return this.connection; + } + + @Override + protected T startTransactionHelper(TransactionLogic logic, TransactionIsolationLevel isolationLevel) + throws StorageQueryException, StorageTransactionLogicException, SQLException, TenantOrAppNotFoundException { + return logic.mainLogicAndCommit(new TransactionConnection(getTransactionConnection())); + } + + @Override + public void commitTransaction(TransactionConnection con) throws StorageQueryException { + // We do not want to commit the queries when using the BulkImportProxyStorage to be able to rollback everything + // if any query fails while importing the user + //super.commitTransaction(con); + } + + @Override + public void initStorage(boolean shouldWait, List tenantIdentifiers) throws DbInitException { + super.initStorage(shouldWait, tenantIdentifiers); + + // `BulkImportProxyStorage` uses `BulkImportProxyConnection`, which overrides the `.commit()` method on the Connection object. + // The `initStorage()` method runs `select * from table_name limit 1` queries to check if the tables exist but these queries + // don't get committed due to the overridden `.commit()`, so we need to manually commit the transaction to remove any locks on the tables. + + // Without this commit, a call to `select * from bulk_import_users limit 1` in `doesTableExist()` locks the `bulk_import_users` table, + try { + this.commitTransactionForBulkImportProxyStorage(); + } catch (StorageQueryException e) { + throw new DbInitException(e); + } + } + + @Override + public void closeConnectionForBulkImportProxyStorage() throws StorageQueryException { + try { + if (this.connection != null) { + this.connection.closeForBulkImportProxyStorage(); + this.connection = null; + } + ConnectionPool.close(this); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void commitTransactionForBulkImportProxyStorage() throws StorageQueryException { + try { + if (this.connection != null) { + this.connection.commitForBulkImportProxyStorage(); + } + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void rollbackTransactionForBulkImportProxyStorage() throws StorageQueryException { + try { + this.connection.rollbackForBulkImportProxyStorage(); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } +} diff --git a/src/main/java/io/supertokens/storage/postgresql/ConnectionPool.java b/src/main/java/io/supertokens/storage/postgresql/ConnectionPool.java index acab61f5..70af1f53 100644 --- a/src/main/java/io/supertokens/storage/postgresql/ConnectionPool.java +++ b/src/main/java/io/supertokens/storage/postgresql/ConnectionPool.java @@ -206,7 +206,7 @@ static void initPool(Start start, boolean shouldWait, PostConnectCallback postCo } } - public static Connection getConnection(Start start) throws SQLException, StorageQueryException { + private static Connection getNewConnection(Start start) throws SQLException, StorageQueryException { if (getInstance(start) == null) { throw new IllegalStateException("Please call initPool before getConnection"); } @@ -219,6 +219,17 @@ public static Connection getConnection(Start start) throws SQLException, Storage return getInstance(start).hikariDataSource.getConnection(); } + public static Connection getConnectionForProxyStorage(Start start) throws SQLException, StorageQueryException { + return getNewConnection(start); + } + + public static Connection getConnection(Start start) throws SQLException, StorageQueryException { + if (start instanceof BulkImportProxyStorage) { + return ((BulkImportProxyStorage) start).getTransactionConnection(); + } + return getNewConnection(start); + } + static void close(Start start) { if (getInstance(start) == null) { return; diff --git a/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java b/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java index 179a871b..9b408b5e 100644 --- a/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java +++ b/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java @@ -51,12 +51,22 @@ static int update(Start start, String QUERY, PreparedStatementValueSetter setter } } - static int update(Connection con, String QUERY, PreparedStatementValueSetter setter) - throws SQLException, StorageQueryException { + static int update(Connection con, String QUERY, PreparedStatementValueSetter setter) throws SQLException { try (PreparedStatement pst = con.prepareStatement(QUERY)) { setter.setValues(pst); return pst.executeUpdate(); } } + static T update(Start start, String QUERY, PreparedStatementValueSetter setter, ResultSetValueExtractor mapper) + throws SQLException, StorageQueryException { + try (Connection con = ConnectionPool.getConnection(start)) { + try (PreparedStatement pst = con.prepareStatement(QUERY)) { + setter.setValues(pst); + try (ResultSet result = pst.executeQuery()) { + return mapper.extract(result); + } + } + } + } } diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index ade74a27..b2ce5963 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -25,11 +25,17 @@ import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.authRecipe.LoginMethod; import io.supertokens.pluginInterface.authRecipe.sqlStorage.AuthRecipeSQLStorage; +import io.supertokens.pluginInterface.bulkimport.BulkImportStorage; +import io.supertokens.pluginInterface.bulkimport.BulkImportUser; +import io.supertokens.pluginInterface.bulkimport.exceptions.BulkImportBatchInsertException; +import io.supertokens.pluginInterface.bulkimport.exceptions.BulkImportTransactionRolledBackException; +import io.supertokens.pluginInterface.bulkimport.sqlStorage.BulkImportSQLStorage; import io.supertokens.pluginInterface.dashboard.DashboardSearchTags; import io.supertokens.pluginInterface.dashboard.DashboardSessionInfo; import io.supertokens.pluginInterface.dashboard.DashboardUser; import io.supertokens.pluginInterface.dashboard.exceptions.UserIdNotFoundException; import io.supertokens.pluginInterface.dashboard.sqlStorage.DashboardSQLStorage; +import io.supertokens.pluginInterface.emailpassword.EmailPasswordImportUser; import io.supertokens.pluginInterface.emailpassword.PasswordResetTokenInfo; import io.supertokens.pluginInterface.emailpassword.exceptions.DuplicateEmailException; import io.supertokens.pluginInterface.emailpassword.exceptions.DuplicatePasswordResetTokenException; @@ -64,12 +70,14 @@ import io.supertokens.pluginInterface.oauth.exception.OAuthClientNotFoundException; import io.supertokens.pluginInterface.passwordless.PasswordlessCode; import io.supertokens.pluginInterface.passwordless.PasswordlessDevice; +import io.supertokens.pluginInterface.passwordless.PasswordlessImportUser; import io.supertokens.pluginInterface.passwordless.exception.*; import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; import io.supertokens.pluginInterface.session.SessionInfo; import io.supertokens.pluginInterface.session.SessionStorage; import io.supertokens.pluginInterface.session.sqlStorage.SessionSQLStorage; import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; +import io.supertokens.pluginInterface.thirdparty.ThirdPartyImportUser; import io.supertokens.pluginInterface.thirdparty.exception.DuplicateThirdPartyUserException; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; import io.supertokens.pluginInterface.totp.TOTPDevice; @@ -103,13 +111,11 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; +import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.SQLException; import java.sql.SQLTransactionRollbackException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Set; +import java.util.*; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; @@ -117,7 +123,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, DashboardSQLStorage, TOTPSQLStorage, - ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage, OAuthStorage { + ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage, OAuthStorage, BulkImportSQLStorage { // these configs are protected from being modified / viewed by the dev using the SuperTokens // SaaS. If the core is not running in SuperTokens SaaS, this array has no effect. @@ -161,6 +167,29 @@ public STORAGE_TYPE getType() { return STORAGE_TYPE.SQL; } + @Override + public Storage createBulkImportProxyStorageInstance() { + return new BulkImportProxyStorage(); + } + + @Override + public void closeConnectionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "closeConnectionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + + @Override + public void commitTransactionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "commitTransactionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + + @Override + public void rollbackTransactionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "rollbackTransactionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + @Override public void loadConfig(JsonObject configJson, Set logLevels, TenantIdentifier tenantIdentifier) throws InvalidConfigException { @@ -271,7 +300,8 @@ public T startTransaction(TransactionLogic logic, TransactionIsolationLev tries++; try { return startTransactionHelper(logic, isolationLevel); - } catch (SQLException | StorageQueryException | StorageTransactionLogicException e) { + } catch (SQLException | StorageQueryException | StorageTransactionLogicException | + TenantOrAppNotFoundException e) { Throwable actualException = e; if (e instanceof StorageQueryException) { actualException = e.getCause(); @@ -308,6 +338,13 @@ public T startTransaction(TransactionLogic logic, TransactionIsolationLev if ((isPSQLRollbackException || isDeadlockException) && tries < NUM_TRIES) { try { + if(this instanceof BulkImportProxyStorage){ + throw new StorageTransactionLogicException(new BulkImportTransactionRolledBackException(e)); + // if the current instance is of BulkImportProxyStorage, that means we are doing a bulk import + // which uses nested transactions. With MySQL this retry logic doesn't going to work, we have + // to retry the whole "big" transaction, not just the innermost, current one. + // @see BulkImportTransactionRolledBackException for more explanation. + } Thread.sleep((long) (10 + (250 + Math.min(Math.pow(2, tries), 3000)) * Math.random())); } catch (InterruptedException ignored) { } @@ -324,14 +361,16 @@ public T startTransaction(TransactionLogic logic, TransactionIsolationLev throw (StorageQueryException) e; } else if (e instanceof StorageTransactionLogicException) { throw (StorageTransactionLogicException) e; + } else if (e instanceof TenantOrAppNotFoundException) { + throw new StorageTransactionLogicException(e); } throw new StorageQueryException(e); } } } - private T startTransactionHelper(TransactionLogic logic, TransactionIsolationLevel isolationLevel) - throws StorageQueryException, StorageTransactionLogicException, SQLException { + protected T startTransactionHelper(TransactionLogic logic, TransactionIsolationLevel isolationLevel) + throws StorageQueryException, StorageTransactionLogicException, SQLException, TenantOrAppNotFoundException { Connection con = null; Integer defaultTransactionIsolation = null; try { @@ -771,13 +810,6 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(AppIdentifier appIdentifier, Str } catch (SQLException e) { throw new StorageQueryException(e); } - } else if (className.equals(TOTPStorage.class.getName())) { - try { - TOTPDevice[] devices = TOTPQueries.getDevices(this, appIdentifier, userId); - return devices.length > 0; - } catch (SQLException e) { - throw new StorageQueryException(e); - } } else if (className.equals(JWTRecipeStorage.class.getName())) { return false; } else if (className.equals(ActiveUsersStorage.class.getName())) { @@ -787,6 +819,73 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(AppIdentifier appIdentifier, Str } } + @Override + public Map> findNonAuthRecipesWhereForUserIdsUsed(AppIdentifier appIdentifier, + List userIds) + throws StorageQueryException { + try { + Map> sessionHandlesByUserId = SessionQueries.getAllNonExpiredSessionHandlesForUsers(this, appIdentifier, userIds); + Map> rolesByUserIds = UserRolesQueries.getRolesForUsers(this, appIdentifier, userIds); + Map userMetadatasByIds = UserMetadataQueries.getMultipleUserMetadatas(this, appIdentifier, userIds); + Set userIdsUsedInEmailVerification = EmailVerificationQueries.findUserIdsBeingUsedForEmailVerification(this, appIdentifier, userIds); + Map> devicesByUserIds = TOTPQueries.getDevicesForMultipleUsers(this, appIdentifier, userIds); + Map lastActivesByUserIds = ActiveUsersQueries.getLastActiveByMultipleUserIds(this, appIdentifier, userIds); + + Map> nonAuthRecipeClassnamesByUserIds = new HashMap<>(); + //session recipe + for(String userId: sessionHandlesByUserId.keySet()){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(SessionStorage.class.getName()); + } + + //role recipe + for(String userId: rolesByUserIds.keySet()){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(UserRolesStorage.class.getName()); + } + + //usermetadata recipe + for(String userId: userMetadatasByIds.keySet()){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(UserMetadataStorage.class.getName()); + } + + //emailverification recipe + for(String userId: userIdsUsedInEmailVerification){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(EmailVerificationStorage.class.getName()); + } + + //totp recipe + for(String userId: devicesByUserIds.keySet()){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(TOTPStorage.class.getName()); + } + + //active users + for(String userId: lastActivesByUserIds.keySet()){ + if(!nonAuthRecipeClassnamesByUserIds.containsKey(userId)){ + nonAuthRecipeClassnamesByUserIds.put(userId, new ArrayList<>()); + } + nonAuthRecipeClassnamesByUserIds.get(userId).add(ActiveUsersStorage.class.getName()); + } + + return nonAuthRecipeClassnamesByUserIds; + } catch (SQLException | StorageTransactionLogicException exc) { + throw new StorageQueryException(exc); + } + } + @TestOnly @Override public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifier, String className, String userId) @@ -874,6 +973,9 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi } } else if (className.equals(JWTRecipeStorage.class.getName())) { /* Since JWT recipe tables do not store userId we do not add any data to them */ + + } else if (className.equals(BulkImportStorage.class.getName())){ + //ignore } else if (className.equals(OAuthStorage.class.getName())) { /* Since OAuth recipe tables do not store userId we do not add any data to them */ } else if (className.equals(ActiveUsersStorage.class.getName())) { @@ -931,6 +1033,60 @@ public AuthRecipeUserInfo signUp(TenantIdentifier tenantIdentifier, String id, S } } + @Override + public void signUpMultipleViaBulkImport_Transaction(TransactionConnection connection, List users) + throws StorageQueryException, StorageTransactionLogicException { + try { + Connection sqlConnection = (Connection) connection.getConnection(); + EmailPasswordQueries.signUpMultipleForBulkImport_Transaction(this, sqlConnection, users); + } catch (StorageQueryException | SQLException | StorageTransactionLogicException e) { + Throwable actual = e.getCause(); + if (actual instanceof BatchUpdateException) { + BatchUpdateException batchUpdateException = (BatchUpdateException) actual; + Map errorByPosition = new HashMap<>(); + SQLException nextException = batchUpdateException.getNextException(); + while (nextException != null) { + + if (nextException instanceof PSQLException) { + PostgreSQLConfig config = Config.getConfig(this); + ServerErrorMessage serverMessage = ((PSQLException) nextException).getServerErrorMessage(); + + int position = getErroneousEntryPosition(batchUpdateException); + if (isUniqueConstraintError(serverMessage, config.getEmailPasswordUserToTenantTable(), + "email")) { + + errorByPosition.put(users.get(position).userId, new DuplicateEmailException()); + + } else if (isPrimaryKeyError(serverMessage, config.getThirdPartyUsersTable()) + || isPrimaryKeyError(serverMessage, config.getUsersTable()) + || isPrimaryKeyError(serverMessage, config.getThirdPartyUserToTenantTable()) + || isPrimaryKeyError(serverMessage, config.getAppIdToUserIdTable())) { + errorByPosition.put(users.get(position).userId, + new io.supertokens.pluginInterface.thirdparty.exception.DuplicateUserIdException()); + } else if (isForeignKeyConstraintError(serverMessage, config.getAppIdToUserIdTable(), "app_id")) { + errorByPosition.put(users.get(position).userId, new TenantOrAppNotFoundException(users.get(position).tenantIdentifier.toAppIdentifier())); + } else if (isForeignKeyConstraintError(serverMessage, config.getUsersTable(), "tenant_id")) { + errorByPosition.put(users.get(position).userId,new TenantOrAppNotFoundException(users.get(position).tenantIdentifier)); + } + + } + nextException = nextException.getNextException(); + } + throw new StorageTransactionLogicException(new BulkImportBatchInsertException("emailpassword errors", errorByPosition)); + } + throw new StorageQueryException(e); + } + } + + private static int getErroneousEntryPosition(BatchUpdateException batchUpdateException) { + String errorMessage = batchUpdateException.getMessage(); + String searchFor = "Batch entry "; + int searchForIndex = errorMessage.indexOf("Batch entry "); + String entryIndex = errorMessage.substring(searchForIndex + searchFor.length(), errorMessage.indexOf(" ", searchForIndex + searchFor.length())); + int position = Integer.parseInt(entryIndex); + return position; + } + @Override public void deleteEmailPasswordUser_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId, boolean deleteUserIdMappingToo) @@ -1108,6 +1264,35 @@ public void updateIsEmailVerified_Transaction(AppIdentifier appIdentifier, Trans } } + @Override + public void updateMultipleIsEmailVerified_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + Map emailToUserId, boolean isEmailVerified) + throws StorageQueryException, TenantOrAppNotFoundException { + Connection sqlCon = (Connection) con.getConnection(); + try { + EmailVerificationQueries.updateMultipleUsersIsEmailVerified_Transaction(this, sqlCon, appIdentifier, + emailToUserId, isEmailVerified); + } catch (SQLException e) { + if (e instanceof PSQLException) { + PostgreSQLConfig config = Config.getConfig(this); + ServerErrorMessage serverMessage = ((PSQLException) e).getServerErrorMessage(); + + if (isForeignKeyConstraintError(serverMessage, config.getEmailVerificationTable(), "app_id")) { + throw new TenantOrAppNotFoundException(appIdentifier); + } + } + + boolean isPSQLPrimKeyError = e instanceof PSQLException && isPrimaryKeyError( + ((PSQLException) e).getServerErrorMessage(), + Config.getConfig(this).getEmailVerificationTable()); + + if (!isEmailVerified || !isPSQLPrimKeyError) { + throw new StorageQueryException(e); + } + // we do not throw an error since the email is already verified + } + } + @Override public void deleteEmailVerificationUserInfo_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) @@ -1215,6 +1400,13 @@ public void updateIsEmailVerifiedToExternalUserId(AppIdentifier appIdentifier, S externalUserId); } + @Override + public void updateMultipleIsEmailVerifiedToExternalUserIds(AppIdentifier appIdentifier, + Map supertokensUserIdToExternalUserId) + throws StorageQueryException { + EmailVerificationQueries.updateMultipleIsEmailVerifiedToExternalUserIds(this, appIdentifier, supertokensUserIdToExternalUserId); + } + @Override public void deleteExpiredPasswordResetTokens() throws StorageQueryException { try { @@ -1292,6 +1484,53 @@ public void deleteThirdPartyUser_Transaction(TransactionConnection con, AppIdent } } + @Override + public void importThirdPartyUsers_Transaction(TransactionConnection con, + List usersToImport) + throws StorageQueryException, StorageTransactionLogicException, TenantOrAppNotFoundException { + try { + Connection sqlCon = (Connection) con.getConnection(); + ThirdPartyQueries.importUser_Transaction(this, sqlCon, usersToImport); + } catch (SQLException e) { + Throwable actual = e.getCause(); + if (actual instanceof BatchUpdateException) { + BatchUpdateException batchUpdateException = (BatchUpdateException) actual; + Map errorByPosition = new HashMap<>(); + SQLException nextException = batchUpdateException.getNextException(); + while (nextException != null) { + + if (nextException instanceof PSQLException) { + PostgreSQLConfig config = Config.getConfig(this); + ServerErrorMessage serverMessage = ((PSQLException) nextException).getServerErrorMessage(); + + int position = getErroneousEntryPosition(batchUpdateException); + if (isUniqueConstraintError(serverMessage, config.getEmailPasswordUserToTenantTable(), + "third_party_user_id")) { + + errorByPosition.put(usersToImport.get(position).userId, new DuplicateThirdPartyUserException()); + + } else if (isPrimaryKeyError(serverMessage, config.getThirdPartyUsersTable()) + || isPrimaryKeyError(serverMessage, config.getUsersTable()) + || isPrimaryKeyError(serverMessage, config.getThirdPartyUserToTenantTable()) + || isPrimaryKeyError(serverMessage, config.getAppIdToUserIdTable())) { + errorByPosition.put(usersToImport.get(position).userId, + new io.supertokens.pluginInterface.thirdparty.exception.DuplicateUserIdException()); + } + else if (isForeignKeyConstraintError(serverMessage, config.getAppIdToUserIdTable(), "app_id")) { + throw new TenantOrAppNotFoundException(usersToImport.get(position).tenantIdentifier.toAppIdentifier()); + + } else if (isForeignKeyConstraintError(serverMessage, config.getUsersTable(), "tenant_id")) { + throw new TenantOrAppNotFoundException(usersToImport.get(position).tenantIdentifier); + } + } + nextException = nextException.getNextException(); + } + throw new StorageTransactionLogicException(new BulkImportBatchInsertException("thirdparty errors", errorByPosition)); + } + throw new StorageQueryException(e); + } + } + @Override public long getUsersCount(TenantIdentifier tenantIdentifier, RECIPE_ID[] includeRecipeIds) throws StorageQueryException { @@ -1383,6 +1622,16 @@ public boolean doesUserIdExist(TenantIdentifier tenantIdentifier, String userId) } } + @Override + public List findExistingUserIds(AppIdentifier appIdentifier, List userIds) + throws StorageQueryException { + try { + return GeneralQueries.findUserIdsThatExist(this, appIdentifier, userIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public AuthRecipeUserInfo getPrimaryUserById(AppIdentifier appIdentifier, String userId) throws StorageQueryException { @@ -1810,6 +2059,61 @@ public void deletePasswordlessUser_Transaction(TransactionConnection con, AppIde } } + @Override + public void importPasswordlessUsers_Transaction(TransactionConnection con, + List users) + throws StorageQueryException, TenantOrAppNotFoundException { + try { + Connection sqlCon = (Connection) con.getConnection(); + PasswordlessQueries.importUsers_Transaction(sqlCon, this, users); + } catch (SQLException e) { + if (e instanceof BatchUpdateException) { + Throwable actual = e.getCause(); + if (actual instanceof BatchUpdateException) { + BatchUpdateException batchUpdateException = (BatchUpdateException) actual; + Map errorByPosition = new HashMap<>(); + SQLException nextException = batchUpdateException.getNextException(); + while (nextException != null) { + + if (nextException instanceof PSQLException) { + PostgreSQLConfig config = Config.getConfig(this); + ServerErrorMessage serverMessage = ((PSQLException) nextException).getServerErrorMessage(); + + int position = getErroneousEntryPosition(batchUpdateException); + + if (isPrimaryKeyError(serverMessage, config.getPasswordlessUsersTable()) + || isPrimaryKeyError(serverMessage, config.getUsersTable()) + || isPrimaryKeyError(serverMessage, config.getPasswordlessUserToTenantTable()) + || isPrimaryKeyError(serverMessage, config.getAppIdToUserIdTable())) { + errorByPosition.put(users.get(position).userId, new DuplicateUserIdException()); + } + if (isUniqueConstraintError(serverMessage, config.getPasswordlessUserToTenantTable(), + "email")) { + errorByPosition.put(users.get(position).userId, new DuplicateEmailException()); + + } else if (isUniqueConstraintError(serverMessage, config.getPasswordlessUserToTenantTable(), + "phone_number")) { + errorByPosition.put(users.get(position).userId, new DuplicatePhoneNumberException()); + + } else if (isForeignKeyConstraintError(serverMessage, config.getAppIdToUserIdTable(), + "app_id")) { + throw new TenantOrAppNotFoundException(users.get(position).tenantIdentifier.toAppIdentifier()); + + } else if (isForeignKeyConstraintError(serverMessage, config.getUsersTable(), + "tenant_id")) { + throw new TenantOrAppNotFoundException(users.get(position).tenantIdentifier.toAppIdentifier()); + } + } + nextException = nextException.getNextException(); + } + throw new StorageQueryException( + new BulkImportBatchInsertException("passwordless errors", errorByPosition)); + } + throw new StorageQueryException(e); + } + } + } + @Override public PasswordlessDevice getDevice(TenantIdentifier tenantIdentifier, String deviceIdHash) throws StorageQueryException { @@ -1902,6 +2206,18 @@ public JsonObject getUserMetadata_Transaction(AppIdentifier appIdentifier, Trans } } + @Override + public Map getMultipleUsersMetadatas_Transaction(AppIdentifier appIdentifier, TransactionConnection + con, List userIds) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return UserMetadataQueries.getMultipleUsersMetadatas_Transaction(this, sqlCon, appIdentifier, userIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public int setUserMetadata_Transaction(AppIdentifier appIdentifier, TransactionConnection con, String userId, JsonObject metadata) @@ -1923,6 +2239,26 @@ public int setUserMetadata_Transaction(AppIdentifier appIdentifier, TransactionC } } + @Override + public void setMultipleUsersMetadatas_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + Map metadataByUserId) + throws StorageQueryException, TenantOrAppNotFoundException { + Connection sqlCon = (Connection) con.getConnection(); + try { + UserMetadataQueries.setMultipleUsersMetadatas_Transaction(this, sqlCon, appIdentifier, metadataByUserId); + } catch (SQLException e) { + if (e instanceof PSQLException) { + PostgreSQLConfig config = Config.getConfig(this); + ServerErrorMessage serverMessage = ((PSQLException) e).getServerErrorMessage(); + + if (isForeignKeyConstraintError(serverMessage, config.getUserMetadataTable(), "app_id")) { + throw new TenantOrAppNotFoundException(appIdentifier); + } + } + throw new StorageQueryException(e); + } + } + @Override public int deleteUserMetadata_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { @@ -1963,6 +2299,17 @@ public void addRoleToUser(TenantIdentifier tenantIdentifier, String userId, Stri } throw new StorageQueryException(e); } + } + + @Override + public void addRolesToUsers_Transaction(TransactionConnection connection, + Map>> rolesToUserByTenants) + throws StorageQueryException { + try { + UserRolesQueries.addRolesToUsers_Transaction(this, (Connection) connection.getConnection(), rolesToUserByTenants); + } catch (SQLException e) { + throw new StorageQueryException(e); + } } @@ -2170,6 +2517,17 @@ public boolean doesRoleExist_Transaction(AppIdentifier appIdentifier, Transactio } } + @Override + public List doesMultipleRoleExist_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + List roles) throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + return UserRolesQueries.doesMultipleRoleExist_transaction(this, sqlCon, appIdentifier, roles); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public void createUserIdMapping(AppIdentifier appIdentifier, String superTokensUserId, String externalUserId, @org.jetbrains.annotations.Nullable String externalUserIdInfo) @@ -2205,6 +2563,19 @@ public void createUserIdMapping(AppIdentifier appIdentifier, String superTokensU } + @Override + public void createBulkUserIdMapping(AppIdentifier appIdentifier, + Map superTokensUserIdToExternalUserId) + throws StorageQueryException { + try { + + UserIdMappingQueries.createBulkUserIdMapping(this, appIdentifier, superTokensUserIdToExternalUserId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + + } + @Override public boolean deleteUserIdMapping(AppIdentifier appIdentifier, String userId, boolean isSuperTokensUserId) throws StorageQueryException { @@ -2709,6 +3080,26 @@ public TOTPDevice createDevice_Transaction(TransactionConnection con, AppIdentif } } + @Override + public void createDevices_Transaction(TransactionConnection con, AppIdentifier appIdentifier, + List devices) + throws StorageQueryException, TenantOrAppNotFoundException { + Connection sqlCon = (Connection) con.getConnection(); + try { + TOTPQueries.createDevices_Transaction(this, sqlCon, appIdentifier, devices); + } catch (SQLException e) { + Exception actualException = e; + + if (actualException instanceof PSQLException) { + ServerErrorMessage errMsg = ((PSQLException) actualException).getServerErrorMessage(); + if (isForeignKeyConstraintError(errMsg, Config.getConfig(this).getTotpUsersTable(), "app_id")) { + throw new TenantOrAppNotFoundException(appIdentifier); + } + } + throw new StorageQueryException(e); + } + } + @Override public TOTPDevice getDeviceByName_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId, String deviceName) throws StorageQueryException { @@ -2908,6 +3299,18 @@ public AuthRecipeUserInfo getPrimaryUserById_Transaction(AppIdentifier appIdenti } } + @Override + public List getPrimaryUsersByIds_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + List userIds) + throws StorageQueryException { + try { + Connection sqlCon = (Connection) con.getConnection(); + return GeneralQueries.getPrimaryUserInfosForUserIds_Transaction(this, sqlCon, appIdentifier, userIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public AuthRecipeUserInfo[] listPrimaryUsersByEmail_Transaction(AppIdentifier appIdentifier, TransactionConnection con, String email) @@ -2920,6 +3323,19 @@ public AuthRecipeUserInfo[] listPrimaryUsersByEmail_Transaction(AppIdentifier ap } } + @Override + public AuthRecipeUserInfo[] listPrimaryUsersByMultipleEmailsOrPhoneNumbersOrThirdparty_Transaction( + AppIdentifier appIdentifier, TransactionConnection con, List emails, List phones, + Map thirdpartyIdToThirdpartyUserId) throws StorageQueryException { + try { + Connection sqlCon = (Connection) con.getConnection(); + return GeneralQueries.listPrimaryUsersByMultipleEmailsOrPhonesOrThirdParty_Transaction(this, sqlCon, + appIdentifier, emails, phones, thirdpartyIdToThirdpartyUserId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public AuthRecipeUserInfo[] listPrimaryUsersByPhoneNumber_Transaction(AppIdentifier appIdentifier, TransactionConnection con, @@ -2975,6 +3391,17 @@ public void makePrimaryUser_Transaction(AppIdentifier appIdentifier, Transaction } } + @Override + public void makePrimaryUsers_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + List userIds) throws StorageQueryException { + try { + Connection sqlCon = (Connection) con.getConnection(); + GeneralQueries.makePrimaryUsers_Transaction(this, sqlCon, appIdentifier, userIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public void linkAccounts_Transaction(AppIdentifier appIdentifier, TransactionConnection con, String recipeUserId, String primaryUserId) throws StorageQueryException { @@ -2988,6 +3415,19 @@ public void linkAccounts_Transaction(AppIdentifier appIdentifier, TransactionCon } } + @Override + public void linkMultipleAccounts_Transaction(AppIdentifier appIdentifier, TransactionConnection con, + Map recipeUserIdByPrimaryUserId) + throws StorageQueryException { + try { + Connection sqlCon = (Connection) con.getConnection(); + GeneralQueries.linkMultipleAccounts_Transaction(this, sqlCon, appIdentifier, recipeUserIdByPrimaryUserId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + + } + @Override public void unlinkAccounts_Transaction(AppIdentifier appIdentifier, TransactionConnection con, String primaryUserId, String recipeUserId) @@ -3079,6 +3519,27 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A } } + @Override + public List getMultipleUserIdMapping_Transaction(TransactionConnection connection, + AppIdentifier appIdentifier, List userIds, + boolean isSupertokensIds) + throws StorageQueryException { + try { + Connection sqlCon = (Connection) connection.getConnection(); + List result; + if(isSupertokensIds){ + result = UserIdMappingQueries.getMultipleUserIdMappingWithSupertokensUserId_Transaction(this, + sqlCon, appIdentifier, userIds); + } else { + result = UserIdMappingQueries.getMultipleUserIdMappingWithExternalUserId_Transaction(this, + sqlCon, appIdentifier, userIds); + } + return result; + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException { @@ -3126,6 +3587,27 @@ public int getDbActivityCount(String dbname) throws SQLException, StorageQueryEx }); } + @Override + public void addBulkImportUsers(AppIdentifier appIdentifier, List users) + throws StorageQueryException, + TenantOrAppNotFoundException, + io.supertokens.pluginInterface.bulkimport.exceptions.DuplicateUserIdException { + try { + BulkImportQueries.insertBulkImportUsers(this, appIdentifier, users); + } catch (SQLException e) { + if (e instanceof PSQLException) { + ServerErrorMessage serverErrorMessage = ((PSQLException) e).getServerErrorMessage(); + if (isPrimaryKeyError(serverErrorMessage, Config.getConfig(this).getBulkImportUsersTable())) { + throw new io.supertokens.pluginInterface.bulkimport.exceptions.DuplicateUserIdException(); + } + if (isForeignKeyConstraintError(serverErrorMessage, Config.getConfig(this).getBulkImportUsersTable(), + "app_id")) { + throw new TenantOrAppNotFoundException(appIdentifier); + } + } + } + } + @Override public OAuthClient getOAuthClientById(AppIdentifier appIdentifier, String clientId) throws StorageQueryException, OAuthClientNotFoundException { @@ -3159,6 +3641,16 @@ public void addOrUpdateOauthClient(AppIdentifier appIdentifier, String clientId, } } + @Override + public List getBulkImportUsers(AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, + @Nullable String bulkImportUserId, @Nullable Long createdAt) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsers(this, appIdentifier, limit, status, bulkImportUserId, createdAt); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public List getOAuthClients(AppIdentifier appIdentifier, List clientIds) throws StorageQueryException { try { @@ -3168,6 +3660,31 @@ public List getOAuthClients(AppIdentifier appIdentifier, List bulkImportUserIdToErrorMessage) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + BulkImportQueries.updateMultipleBulkImportUsersStatusToError_Transaction(this, sqlCon, appIdentifier, + bulkImportUserIdToErrorMessage); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public boolean revokeOAuthTokenByGID(AppIdentifier appIdentifier, String gid) throws StorageQueryException { try { @@ -3187,6 +3704,15 @@ public boolean revokeOAuthTokenByClientId(AppIdentifier appIdentifier, String cl } } + @Override + public List deleteBulkImportUsers(AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds) throws StorageQueryException { + try { + return BulkImportQueries.deleteBulkImportUsers(this, appIdentifier, bulkImportUserIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public boolean revokeOAuthTokenByJTI(AppIdentifier appIdentifier, String gid, String jti) throws StorageQueryException { @@ -3197,6 +3723,24 @@ public boolean revokeOAuthTokenByJTI(AppIdentifier appIdentifier, String gid, St } } + @Override + public List getBulkImportUsersAndChangeStatusToProcessing(AppIdentifier appIdentifier, @Nonnull Integer limit) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsersAndChangeStatusToProcessing(this, appIdentifier, limit); + } catch (StorageTransactionLogicException e) { + throw new StorageQueryException(e.actualException); + } + } + + @Override + public void updateBulkImportUserPrimaryUserId(AppIdentifier appIdentifier, @Nonnull String bulkImportUserId, @Nonnull String primaryUserId) throws StorageQueryException { + try { + BulkImportQueries.updateBulkImportUserPrimaryUserId(this, appIdentifier, bulkImportUserId, primaryUserId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public boolean revokeOAuthTokenBySessionHandle(AppIdentifier appIdentifier, String sessionHandle) throws StorageQueryException { @@ -3207,6 +3751,15 @@ public boolean revokeOAuthTokenBySessionHandle(AppIdentifier appIdentifier, Stri } } + @Override + public long getBulkImportUsersCount(AppIdentifier appIdentifier, @Nullable BULK_IMPORT_USER_STATUS status) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsersCount(this, appIdentifier, status); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + @Override public void addOAuthM2MTokenForStats(AppIdentifier appIdentifier, String clientId, long iat, long exp) throws StorageQueryException, OAuthClientNotFoundException { diff --git a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java index 89aad8b7..6e915ca8 100644 --- a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java +++ b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java @@ -462,6 +462,10 @@ public String getOAuthLogoutChallengesTable() { return addSchemaAndPrefixToTableName("oauth_logout_challenges"); } + public String getBulkImportUsersTable() { + return addSchemaAndPrefixToTableName("bulk_import_users"); + } + private String addSchemaAndPrefixToTableName(String tableName) { return addSchemaToTableName(postgresql_table_names_prefix + tableName); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java index ccd589ac..50edb15c 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/ActiveUsersQueries.java @@ -5,11 +5,13 @@ import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; +import org.jetbrains.annotations.TestOnly; import java.sql.Connection; import java.sql.SQLException; - -import org.jetbrains.annotations.TestOnly; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -133,6 +135,30 @@ public static Long getLastActiveByUserId(Start start, AppIdentifier appIdentifie } } + public static Map getLastActiveByMultipleUserIds(Start start, AppIdentifier appIdentifier, List userIds) + throws StorageQueryException { + String QUERY = "SELECT user_id, last_active_time FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND user_id IN ( " + Utils.generateCommaSeperatedQuestionMarks(userIds.size())+ " )"; + + try { + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < userIds.size(); i++) { + pst.setString(2+i, userIds.get(i)); + } + }, res -> { + Map lastActiveByUserIds = new HashMap<>(); + if (res.next()) { + String userId = res.getString("user_id"); + lastActiveByUserIds.put(userId, res.getLong("last_active_time")); + } + return lastActiveByUserIds; + }); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + public static void deleteUserActive_Transaction(Connection con, Start start, AppIdentifier appIdentifier, String userId) throws StorageQueryException, SQLException { diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java new file mode 100644 index 00000000..f4874540 --- /dev/null +++ b/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java @@ -0,0 +1,347 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.postgresql.queries; + +import io.supertokens.pluginInterface.RowMapper; +import io.supertokens.pluginInterface.bulkimport.BulkImportStorage.BULK_IMPORT_USER_STATUS; +import io.supertokens.pluginInterface.bulkimport.BulkImportUser; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.storage.postgresql.Start; +import io.supertokens.storage.postgresql.config.Config; +import io.supertokens.storage.postgresql.utils.Utils; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; +import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; + +public class BulkImportQueries { + static String getQueryToCreateBulkImportUsersTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String tableName = Config.getConfig(start).getBulkImportUsersTable(); + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "id CHAR(36)," + + "app_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "primary_user_id VARCHAR(36)," + + "raw_data TEXT NOT NULL," + + "status VARCHAR(128) DEFAULT 'NEW'," + + "error_msg TEXT," + + "created_at BIGINT NOT NULL, " + + "updated_at BIGINT NOT NULL, " + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey") + + " PRIMARY KEY(app_id, id)," + + "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "app_id", "fkey") + " " + + "FOREIGN KEY(app_id) " + + "REFERENCES " + Config.getConfig(start).getAppsTable() + " (app_id) ON DELETE CASCADE" + + " );"; + } + + public static String getQueryToCreateStatusUpdatedAtIndex(Start start) { + return "CREATE INDEX IF NOT EXISTS bulk_import_users_status_updated_at_index ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, status, updated_at)"; + } + + public static String getQueryToCreatePaginationIndex1(Start start) { + return "CREATE INDEX IF NOT EXISTS bulk_import_users_pagination_index1 ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, status, created_at DESC, id DESC)"; + } + + public static String getQueryToCreatePaginationIndex2(Start start) { + return "CREATE INDEX IF NOT EXISTS bulk_import_users_pagination_index2 ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, created_at DESC, id DESC)"; + } + + public static void insertBulkImportUsers(Start start, AppIdentifier appIdentifier, List users) + throws SQLException, StorageQueryException { + StringBuilder queryBuilder = new StringBuilder( + "INSERT INTO " + Config.getConfig(start).getBulkImportUsersTable() + " (id, app_id, raw_data, created_at, updated_at) VALUES "); + + int userCount = users.size(); + + for (int i = 0; i < userCount; i++) { + queryBuilder.append(" (?, ?, ?, ?, ?)"); + + if (i < userCount - 1) { + queryBuilder.append(","); + } + } + + update(start, queryBuilder.toString(), pst -> { + int parameterIndex = 1; + for (BulkImportUser user : users) { + pst.setString(parameterIndex++, user.id); + pst.setString(parameterIndex++, appIdentifier.getAppId()); + pst.setString(parameterIndex++, user.toRawDataForDbStorage()); + pst.setLong(parameterIndex++, System.currentTimeMillis()); + pst.setLong(parameterIndex++, System.currentTimeMillis()); + } + }); + } + + public static void updateBulkImportUserStatus_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + @Nonnull String bulkImportUserId, @Nonnull BULK_IMPORT_USER_STATUS status, @Nullable String errorMessage) + throws SQLException { + String query = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET status = ?, error_msg = ?, updated_at = ? WHERE app_id = ? and id = ?"; + + List parameters = new ArrayList<>(); + + parameters.add(status.toString()); + parameters.add(errorMessage); + parameters.add(System.currentTimeMillis()); + parameters.add(appIdentifier.getAppId()); + parameters.add(bulkImportUserId); + + update(con, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }); + } + + public static void updateMultipleBulkImportUsersStatusToError_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + @Nonnull Map bulkImportUserIdToErrorMessage) + throws SQLException { + BULK_IMPORT_USER_STATUS errorStatus = BULK_IMPORT_USER_STATUS.FAILED; + String query = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET status = ?, error_msg = ?, updated_at = ? WHERE app_id = ? and id = ?"; + + PreparedStatement setErrorStatement = con.prepareStatement(query); + + int counter = 0; + for(String bulkImportUserId : bulkImportUserIdToErrorMessage.keySet()){ + setErrorStatement.setString(1, errorStatus.toString()); + setErrorStatement.setString(2, bulkImportUserIdToErrorMessage.get(bulkImportUserId)); + setErrorStatement.setLong(3, System.currentTimeMillis()); + setErrorStatement.setString(4, appIdentifier.getAppId()); + setErrorStatement.setString(5, bulkImportUserId); + setErrorStatement.addBatch(); + + if(counter % 100 == 0) { + setErrorStatement.executeBatch(); + } + } + + setErrorStatement.executeBatch(); + } + + public static List getBulkImportUsersAndChangeStatusToProcessing(Start start, + AppIdentifier appIdentifier, + @Nonnull Integer limit) + throws StorageQueryException, StorageTransactionLogicException { + + return start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + try { + + // "FOR UPDATE" ensures that multiple cron jobs don't read the same rows simultaneously. + // If one process locks the first 1000 rows, others will wait for the lock to be released. + // "SKIP LOCKED" allows other processes to skip locked rows and select the next 1000 available rows. + String selectQuery = "SELECT * FROM " + Config.getConfig(start).getBulkImportUsersTable() + + " WHERE app_id = ?" + //+ " AND (status = 'NEW' OR (status = 'PROCESSING' AND updated_at < (EXTRACT(EPOCH FROM CURRENT_TIMESTAMP) * 1000) - 10 * 60 * 1000))" /* 10 mins */ + + " AND (status = 'NEW' OR status = 'PROCESSING' )" + + " LIMIT ? FOR UPDATE SKIP LOCKED"; + + List bulkImportUsers = new ArrayList<>(); + + execute(sqlCon, selectQuery, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setInt(2, limit); + }, result -> { + while (result.next()) { + bulkImportUsers.add(BulkImportUserRowMapper.getInstance().mapOrThrow(result)); + } + return null; + }); + + if (bulkImportUsers.isEmpty()) { + return new ArrayList<>(); + } + + String updateQuery = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET status = ?, updated_at = ? WHERE app_id = ? AND id IN (" + Utils + .generateCommaSeperatedQuestionMarks(bulkImportUsers.size()) + ")"; + + update(sqlCon, updateQuery, pst -> { + int index = 1; + pst.setString(index++, BULK_IMPORT_USER_STATUS.PROCESSING.toString()); + pst.setLong(index++, System.currentTimeMillis()); + pst.setString(index++, appIdentifier.getAppId()); + for (BulkImportUser user : bulkImportUsers) { + pst.setObject(index++, user.id); + } + }); + return bulkImportUsers; + } catch (SQLException throwables) { + throw new StorageTransactionLogicException(throwables); + } + }); + } + + public static List getBulkImportUsers(Start start, AppIdentifier appIdentifier, + @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, + @Nullable String bulkImportUserId, @Nullable Long createdAt) + throws SQLException, StorageQueryException { + + String baseQuery = "SELECT * FROM " + Config.getConfig(start).getBulkImportUsersTable(); + + StringBuilder queryBuilder = new StringBuilder(baseQuery); + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + if (status != null) { + queryBuilder.append(" AND status = ?"); + parameters.add(status.toString()); + } + + if (bulkImportUserId != null && createdAt != null) { + queryBuilder + .append(" AND (created_at < ? OR (created_at = ? AND id <= ?))"); + parameters.add(createdAt); + parameters.add(createdAt); + parameters.add(bulkImportUserId); + } + + queryBuilder.append(" ORDER BY created_at DESC, id DESC LIMIT ?"); + parameters.add(limit); + + String query = queryBuilder.toString(); + + return execute(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + List bulkImportUsers = new ArrayList<>(); + while (result.next()) { + bulkImportUsers.add(BulkImportUserRowMapper.getInstance().mapOrThrow(result)); + } + return bulkImportUsers; + }); + } + + public static List deleteBulkImportUsers(Start start, AppIdentifier appIdentifier, + @Nonnull String[] bulkImportUserIds) throws SQLException, StorageQueryException { + if (bulkImportUserIds.length == 0) { + return new ArrayList<>(); + } + + String baseQuery = "DELETE FROM " + Config.getConfig(start).getBulkImportUsersTable(); + StringBuilder queryBuilder = new StringBuilder(baseQuery); + + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + queryBuilder.append(" AND id IN ("); + for (int i = 0; i < bulkImportUserIds.length; i++) { + if (i != 0) { + queryBuilder.append(", "); + } + queryBuilder.append("?"); + parameters.add(bulkImportUserIds[i]); + } + queryBuilder.append(") RETURNING id"); + + String query = queryBuilder.toString(); + + return update(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + List deletedIds = new ArrayList<>(); + while (result.next()) { + deletedIds.add(result.getString("id")); + } + return deletedIds; + }); + } + + public static void updateBulkImportUserPrimaryUserId(Start start, AppIdentifier appIdentifier, + @Nonnull String bulkImportUserId, + @Nonnull String primaryUserId) throws SQLException, StorageQueryException { + String query = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET primary_user_id = ?, updated_at = ? WHERE app_id = ? and id = ?"; + + update(start, query, pst -> { + pst.setString(1, primaryUserId); + pst.setLong(2, System.currentTimeMillis()); + pst.setString(3, appIdentifier.getAppId()); + pst.setString(4, bulkImportUserId); + }); + } + + public static long getBulkImportUsersCount(Start start, AppIdentifier appIdentifier, @Nullable BULK_IMPORT_USER_STATUS status) throws SQLException, StorageQueryException { + String baseQuery = "SELECT COUNT(*) FROM " + Config.getConfig(start).getBulkImportUsersTable(); + StringBuilder queryBuilder = new StringBuilder(baseQuery); + + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + if (status != null) { + queryBuilder.append(" AND status = ?"); + parameters.add(status.toString()); + } + + String query = queryBuilder.toString(); + + return execute(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + result.next(); + return result.getLong(1); + }); + } + + private static class BulkImportUserRowMapper implements RowMapper { + private static final BulkImportUserRowMapper INSTANCE = new BulkImportUserRowMapper(); + + private BulkImportUserRowMapper() { + } + + private static BulkImportUserRowMapper getInstance() { + return INSTANCE; + } + + @Override + public BulkImportUser map(ResultSet result) throws Exception { + return BulkImportUser.fromRawDataFromDbStorage(result.getString("id"), result.getString("raw_data"), + BULK_IMPORT_USER_STATUS.valueOf(result.getString("status")), + result.getString("primary_user_id"), result.getString("error_msg"), result.getLong("created_at"), + result.getLong("updated_at")); + } + } +} diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/EmailPasswordQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/EmailPasswordQueries.java index efed6c7f..deb2bfe8 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/EmailPasswordQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/EmailPasswordQueries.java @@ -19,19 +19,19 @@ import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.authRecipe.LoginMethod; +import io.supertokens.pluginInterface.emailpassword.EmailPasswordImportUser; import io.supertokens.pluginInterface.emailpassword.PasswordResetTokenInfo; -import io.supertokens.pluginInterface.emailpassword.exceptions.DuplicateEmailException; import io.supertokens.pluginInterface.emailpassword.exceptions.UnknownUserIdException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; -import io.supertokens.storage.postgresql.ConnectionPool; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; @@ -336,6 +336,83 @@ public static AuthRecipeUserInfo signUp(Start start, TenantIdentifier tenantIden }); } + public static void signUpMultipleForBulkImport_Transaction(Start start, Connection sqlCon, List usersToSignUp) + throws StorageQueryException, StorageTransactionLogicException, SQLException { + try { + String app_id_to_user_id_QUERY = "INSERT INTO " + getConfig(start).getAppIdToUserIdTable() + + "(app_id, user_id, primary_or_recipe_user_id, recipe_id)" + " VALUES(?, ?, ?, ?)"; + + String all_auth_recipe_users_QUERY = "INSERT INTO " + getConfig(start).getUsersTable() + + "(app_id, tenant_id, user_id, primary_or_recipe_user_id, recipe_id, time_joined, " + + "primary_or_recipe_user_time_joined)" + + " VALUES(?, ?, ?, ?, ?, ?, ?)"; + + String emailpassword_users_QUERY = "INSERT INTO " + getConfig(start).getEmailPasswordUsersTable() + + "(app_id, user_id, email, password_hash, time_joined)" + " VALUES(?, ?, ?, ?, ?)"; + + String emailpassword_users_to_tenant_QUERY = + "INSERT INTO " + getConfig(start).getEmailPasswordUserToTenantTable() + + "(app_id, tenant_id, user_id, email)" + " VALUES(?, ?, ?, ?)"; + + PreparedStatement appIdToUserId = sqlCon.prepareStatement(app_id_to_user_id_QUERY); + PreparedStatement allAuthRecipeUsers = sqlCon.prepareStatement(all_auth_recipe_users_QUERY); + PreparedStatement emailPasswordUsers = sqlCon.prepareStatement(emailpassword_users_QUERY); + PreparedStatement emailPasswordUsersToTenant = sqlCon.prepareStatement(emailpassword_users_to_tenant_QUERY); + + int counter = 0; + for (EmailPasswordImportUser user : usersToSignUp) { + String userId = user.userId; + TenantIdentifier tenantIdentifier = user.tenantIdentifier; + + appIdToUserId.setString(1, tenantIdentifier.getAppId()); + appIdToUserId.setString(2, userId); + appIdToUserId.setString(3, userId); + appIdToUserId.setString(4, EMAIL_PASSWORD.toString()); + appIdToUserId.addBatch(); + + + allAuthRecipeUsers.setString(1, tenantIdentifier.getAppId()); + allAuthRecipeUsers.setString(2, tenantIdentifier.getTenantId()); + allAuthRecipeUsers.setString(3, userId); + allAuthRecipeUsers.setString(4, userId); + allAuthRecipeUsers.setString(5, EMAIL_PASSWORD.toString()); + allAuthRecipeUsers.setLong(6, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsers.setLong(7, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsers.addBatch(); + + emailPasswordUsers.setString(1, tenantIdentifier.getAppId()); + emailPasswordUsers.setString(2, userId); + emailPasswordUsers.setString(3, user.email); + emailPasswordUsers.setString(4, user.passwordHash); + emailPasswordUsers.setLong(5, user.timeJoinedMSSinceEpoch); + emailPasswordUsers.addBatch(); + + emailPasswordUsersToTenant.setString(1, tenantIdentifier.getAppId()); + emailPasswordUsersToTenant.setString(2, tenantIdentifier.getTenantId()); + emailPasswordUsersToTenant.setString(3, userId); + emailPasswordUsersToTenant.setString(4, user.email); + emailPasswordUsersToTenant.addBatch(); + counter++; + if (counter % 100 == 0) { + appIdToUserId.executeBatch(); + allAuthRecipeUsers.executeBatch(); + emailPasswordUsers.executeBatch(); + emailPasswordUsersToTenant.executeBatch(); + } + } + + //execute the remaining ones + appIdToUserId.executeBatch(); + allAuthRecipeUsers.executeBatch(); + emailPasswordUsers.executeBatch(); + emailPasswordUsersToTenant.executeBatch(); + + sqlCon.commit(); + } catch (SQLException throwables) { + throw new StorageTransactionLogicException(throwables); + } + } + public static void deleteUser_Transaction(Connection sqlCon, Start start, AppIdentifier appIdentifier, String userId, boolean deleteUserIdMappingToo) throws StorageQueryException, SQLException { @@ -479,6 +556,30 @@ public static String lockEmail_Transaction(Start start, Connection con, }); } + public static List lockEmail_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails == null || emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT user_id FROM " + getConfig(start).getEmailPasswordUsersTable() + + " WHERE app_id = ? AND email IN (" + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + ") FOR UPDATE"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2 + i, emails.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(result.getString("user_id")); + } + return results; + }); + } + public static String getPrimaryUserIdUsingEmail(Start start, TenantIdentifier tenantIdentifier, String email) throws StorageQueryException, SQLException { @@ -522,6 +623,33 @@ public static List getPrimaryUserIdsUsingEmail_Transaction(Start start, }); } + public static List getPrimaryUserIdsUsingMultipleEmails_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT DISTINCT all_users.primary_or_recipe_user_id AS user_id " + + "FROM " + getConfig(start).getEmailPasswordUsersTable() + " AS ep" + + " JOIN " + getConfig(start).getAppIdToUserIdTable() + " AS all_users" + + " ON ep.app_id = all_users.app_id AND ep.user_id = all_users.user_id" + + " WHERE ep.app_id = ? AND ep.email IN ( " + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + " )"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2+i, emails.get(i)); + } + }, result -> { + List userIds = new ArrayList<>(); + while (result.next()) { + userIds.add(result.getString("user_id")); + } + return userIds; + }); + } + public static boolean addUserIdToTenant_Transaction(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String userId) throws SQLException, StorageQueryException, UnknownUserIdException { diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java index 9c70cf8f..fc3fa372 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/EmailVerificationQueries.java @@ -28,6 +28,7 @@ import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; @@ -124,6 +125,48 @@ public static void updateUsersIsEmailVerified_Transaction(Start start, Connectio } } + public static void updateMultipleUsersIsEmailVerified_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + Map emailToUserIds, + boolean isEmailVerified) + throws SQLException, StorageQueryException { + + if (isEmailVerified) { + String QUERY = "INSERT INTO " + getConfig(start).getEmailVerificationTable() + + "(app_id, user_id, email) VALUES(?, ?, ?)"; + PreparedStatement insertQuery = con.prepareStatement(QUERY); + int counter = 0; + for(Map.Entry emailToUser : emailToUserIds.entrySet()){ + insertQuery.setString(1, appIdentifier.getAppId()); + insertQuery.setString(2, emailToUser.getKey()); + insertQuery.setString(3, emailToUser.getValue()); + insertQuery.addBatch(); + + counter++; + if (counter % 100 == 0) { + insertQuery.executeBatch(); + } + } + insertQuery.executeBatch(); + } else { + String QUERY = "DELETE FROM " + getConfig(start).getEmailVerificationTable() + + " WHERE app_id = ? AND user_id = ? AND email = ?"; + PreparedStatement deleteQuery = con.prepareStatement(QUERY); + int counter = 0; + for (Map.Entry emailToUser : emailToUserIds.entrySet()) { + deleteQuery.setString(1, appIdentifier.getAppId()); + deleteQuery.setString(2, emailToUser.getValue()); + deleteQuery.setString(3, emailToUser.getKey()); + deleteQuery.addBatch(); + + counter++; + if (counter % 100 == 0) { + deleteQuery.executeBatch(); + } + } + deleteQuery.executeBatch(); + } + } + public static void deleteAllEmailVerificationTokensForUser_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, String userId, @@ -481,6 +524,45 @@ public static boolean isUserIdBeingUsedForEmailVerification(Start start, AppIden } } + public static Set findUserIdsBeingUsedForEmailVerification(Start start, AppIdentifier appIdentifier, List userIds) + throws SQLException, StorageQueryException { + + Set foundUserIds = new HashSet<>(); + + String email_verificiation_tokens_QUERY = "SELECT * FROM " + getConfig(start).getEmailVerificationTokensTable() + + " WHERE app_id = ? AND user_id IN (" + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) +")"; + + foundUserIds.addAll(execute(start, email_verificiation_tokens_QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < userIds.size(); i++) { + pst.setString(2 + i, userIds.get(i)); + } + }, result -> { + Set userIdsFound = new HashSet<>(); + while (result.next()) { + userIdsFound.add(result.getString("user_id")); + } + return userIdsFound; + })); + + String email_verification_table_QUERY = "SELECT * FROM " + getConfig(start).getEmailVerificationTable() + + " WHERE app_id = ? AND user_id IN (" + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) +")"; + + foundUserIds.addAll(execute(start, email_verification_table_QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < userIds.size(); i++) { + pst.setString(2 + i, userIds.get(i)); + } + }, result -> { + Set userIdsFound = new HashSet<>(); + while (result.next()) { + userIdsFound.add(result.getString("user_id")); + } + return userIdsFound; + })); + return foundUserIds; + } + public static void updateIsEmailVerifiedToExternalUserId(Start start, AppIdentifier appIdentifier, String supertokensUserId, String externalUserId) throws StorageQueryException { @@ -517,6 +599,52 @@ public static void updateIsEmailVerifiedToExternalUserId(Start start, AppIdentif } } + public static void updateMultipleIsEmailVerifiedToExternalUserIds(Start start, AppIdentifier appIdentifier, + Map supertokensUserIdToExternalUserId) + throws StorageQueryException { + try { + start.startTransaction((TransactionConnection con) -> { + Connection sqlCon = (Connection) con.getConnection(); + try { + String update_email_verification_table_query = "UPDATE " + getConfig(start).getEmailVerificationTable() + + " SET user_id = ? WHERE app_id = ? AND user_id = ?"; + String update_email_verification_tokens_table_query = "UPDATE " + getConfig(start).getEmailVerificationTokensTable() + + " SET user_id = ? WHERE app_id = ? AND user_id = ?"; + PreparedStatement updateEmailVerificationQuery = sqlCon.prepareStatement(update_email_verification_table_query); + PreparedStatement updateEmailVerificationTokensQuery = sqlCon.prepareStatement(update_email_verification_tokens_table_query); + + int counter = 0; + for (String supertokensUserId : supertokensUserIdToExternalUserId.keySet()){ + updateEmailVerificationQuery.setString(1, supertokensUserIdToExternalUserId.get(supertokensUserId)); + updateEmailVerificationQuery.setString(2, appIdentifier.getAppId()); + updateEmailVerificationQuery.setString(3, supertokensUserId); + updateEmailVerificationQuery.addBatch(); + + updateEmailVerificationTokensQuery.setString(1, supertokensUserIdToExternalUserId.get(supertokensUserId)); + updateEmailVerificationTokensQuery.setString(2, appIdentifier.getAppId()); + updateEmailVerificationTokensQuery.setString(3, supertokensUserId); + updateEmailVerificationTokensQuery.addBatch(); + + counter++; + if(counter % 100 == 0) { + updateEmailVerificationQuery.executeBatch(); + updateEmailVerificationTokensQuery.executeBatch(); + } + } + updateEmailVerificationQuery.executeBatch(); + updateEmailVerificationTokensQuery.executeBatch(); + + } catch (SQLException e) { + throw new StorageTransactionLogicException(e); + } + + return null; + }); + } catch (StorageTransactionLogicException e) { + throw new StorageQueryException(e.actualException); + } + } + private static class EmailVerificationTokenInfoRowMapper implements RowMapper { private static final EmailVerificationTokenInfoRowMapper INSTANCE = new EmailVerificationTokenInfoRowMapper(); diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java index 83d8bc0e..b0c5c7ee 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java @@ -34,6 +34,7 @@ import org.jetbrains.annotations.TestOnly; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; @@ -103,8 +104,16 @@ static String getQueryToCreateUsersTable(Start start) { public static String getQueryToCreateUserIdIndexForUsersTable(Start start) { return "CREATE INDEX IF NOT EXISTS all_auth_recipe_user_id_index ON " + + Config.getConfig(start).getUsersTable() + "(user_id);"; + } + public static String getQueryToCreateUserIdAppIdIndexForUsersTable(Start start) { + return "CREATE INDEX IF NOT EXISTS all_auth_recipe_user_id_app_id_index ON " + Config.getConfig(start).getUsersTable() + "(app_id, user_id);"; } + public static String getQueryToCreateAppIdIndexForUsersTable(Start start) { + return "CREATE INDEX IF NOT EXISTS all_auth_recipe_user_app_id_index ON " + + Config.getConfig(start).getUsersTable() + "(app_id);"; + } public static String getQueryToCreateTenantIdIndexForUsersTable(Start start) { return "CREATE INDEX IF NOT EXISTS all_auth_recipe_tenant_id_index ON " @@ -247,6 +256,11 @@ static String getQueryToCreatePrimaryUserIdIndexForAppIdToUserIdTable(Start star + Config.getConfig(start).getAppIdToUserIdTable() + "(primary_or_recipe_user_id, app_id);"; } + static String getQueryToCreateUserIdIndexForAppIdToUserIdTable(Start start) { + return "CREATE INDEX IF NOT EXISTS app_id_to_user_id_user_id_index ON " + + Config.getConfig(start).getAppIdToUserIdTable() + "(user_id, app_id);"; + } + public static void createTablesIfNotExists(Start start, Connection con) throws SQLException, StorageQueryException { int numberOfRetries = 0; boolean retry = true; @@ -281,6 +295,7 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S // index update(con, getQueryToCreateAppIdIndexForAppIdToUserIdTable(start), NO_OP_SETTER); update(con, getQueryToCreatePrimaryUserIdIndexForAppIdToUserIdTable(start), NO_OP_SETTER); + update(con, getQueryToCreateUserIdIndexForAppIdToUserIdTable(start), NO_OP_SETTER); } if (!doesTableExists(start, con, Config.getConfig(start).getUsersTable())) { @@ -433,6 +448,8 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S // index update(con, getQueryToCreateUserIdIndexForUsersTable(start), NO_OP_SETTER); + update(con, getQueryToCreateUserIdAppIdIndexForUsersTable(start), NO_OP_SETTER); + update(con, getQueryToCreateAppIdIndexForUsersTable(start), NO_OP_SETTER); update(con, getQueryToCreateTenantIdIndexForUsersTable(start), NO_OP_SETTER); } @@ -554,6 +571,15 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S update(con, TOTPQueries.getQueryToCreateTenantIdIndexForUsedCodesTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, con, Config.getConfig(start).getBulkImportUsersTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, BulkImportQueries.getQueryToCreateBulkImportUsersTable(start), NO_OP_SETTER); + // index: + update(start, BulkImportQueries.getQueryToCreateStatusUpdatedAtIndex(start), NO_OP_SETTER); + update(start, BulkImportQueries.getQueryToCreatePaginationIndex1(start), NO_OP_SETTER); + update(start, BulkImportQueries.getQueryToCreatePaginationIndex2(start), NO_OP_SETTER); + } + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthClientsTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(con, OAuthQueries.getQueryToCreateOAuthClientTable(start), NO_OP_SETTER); @@ -620,7 +646,18 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer String DROP_QUERY = "DROP INDEX IF EXISTS all_auth_recipe_users_pagination_index"; update(start, DROP_QUERY, NO_OP_SETTER); } - + { + String DROP_QUERY = "DROP INDEX IF EXISTS bulk_import_users_status_updated_at_index"; + update(start, DROP_QUERY, NO_OP_SETTER); + } + { + String DROP_QUERY = "DROP INDEX IF EXISTS bulk_import_users_pagination_index1"; + update(start, DROP_QUERY, NO_OP_SETTER); + } + { + String DROP_QUERY = "DROP INDEX IF EXISTS bulk_import_users_pagination_index2"; + update(start, DROP_QUERY, NO_OP_SETTER); + } { String DROP_QUERY = "DROP TABLE IF EXISTS " + getConfig(start).getAppsTable() + "," @@ -661,10 +698,12 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer + getConfig(start).getTotpUsedCodesTable() + "," + getConfig(start).getTotpUserDevicesTable() + "," + getConfig(start).getTotpUsersTable() + "," + + getConfig(start).getBulkImportUsersTable() + "," + getConfig(start).getOAuthClientsTable() + "," + getConfig(start).getOAuthSessionsTable() + "," + getConfig(start).getOAuthLogoutChallengesTable() + "," + getConfig(start).getOAuthM2MTokensTable(); + update(start, DROP_QUERY, NO_OP_SETTER); } } @@ -857,6 +896,24 @@ public static boolean doesUserIdExist(Start start, TenantIdentifier tenantIdenti }, ResultSet::next); } + public static List findUserIdsThatExist(Start start, AppIdentifier appIdentifier, List userIds) + throws SQLException, StorageQueryException { + String QUERY = "SELECT user_id FROM " + getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? AND user_id IN ("+ Utils.generateCommaSeperatedQuestionMarks(userIds.size()) +")"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for(int i = 0; i { + List foundUserIds = new ArrayList<>(); + while(result.next()){ + foundUserIds.add(result.getString(1)); + } + return foundUserIds; + }); + } + public static AuthRecipeUserInfo[] getUsers(Start start, TenantIdentifier tenantIdentifier, @NotNull Integer limit, @NotNull String timeJoinedOrder, @Nullable RECIPE_ID[] includeRecipeIds, @Nullable String userId, @@ -1182,6 +1239,37 @@ public static void makePrimaryUser_Transaction(Start start, Connection sqlCon, A } } + public static void makePrimaryUsers_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, + List userIds) + throws SQLException, StorageQueryException { + + String users_update_QUERY = "UPDATE " + getConfig(start).getUsersTable() + + " SET is_linked_or_is_a_primary_user = true WHERE app_id = ? AND user_id = ?"; + String appid_to_userid_update_QUERY = "UPDATE " + getConfig(start).getAppIdToUserIdTable() + + " SET is_linked_or_is_a_primary_user = true WHERE app_id = ? AND user_id = ?"; + + PreparedStatement usersUpdateStatement = sqlCon.prepareStatement(users_update_QUERY); + PreparedStatement appIdToUserIdUpdateStatement = sqlCon.prepareStatement(appid_to_userid_update_QUERY); + int counter = 0; + for(String userId: userIds){ + usersUpdateStatement.setString(1, appIdentifier.getAppId()); + usersUpdateStatement.setString(2, userId); + usersUpdateStatement.addBatch(); + + appIdToUserIdUpdateStatement.setString(1, appIdentifier.getAppId()); + appIdToUserIdUpdateStatement.setString(2, userId); + appIdToUserIdUpdateStatement.addBatch(); + + counter++; + if(counter % 100 == 0) { + usersUpdateStatement.executeBatch(); + appIdToUserIdUpdateStatement.executeBatch(); + } + } + usersUpdateStatement.executeBatch(); + appIdToUserIdUpdateStatement.executeBatch(); + } + public static void linkAccounts_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String recipeUserId, String primaryUserId) throws SQLException, StorageQueryException { @@ -1212,6 +1300,73 @@ public static void linkAccounts_Transaction(Start start, Connection sqlCon, AppI } } + public static void linkMultipleAccounts_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, + Map recipeUserIdToPrimaryUserId) + throws SQLException, StorageQueryException { + + if(recipeUserIdToPrimaryUserId == null || recipeUserIdToPrimaryUserId.isEmpty()){ + return; + } + + String update_users_QUERY = "UPDATE " + getConfig(start).getUsersTable() + + " SET is_linked_or_is_a_primary_user = true, primary_or_recipe_user_id = ? WHERE app_id = ? AND " + + "user_id = ?"; + + String update_appid_to_userid_QUERY = "UPDATE " + getConfig(start).getAppIdToUserIdTable() + + " SET is_linked_or_is_a_primary_user = true, primary_or_recipe_user_id = ? WHERE app_id = ? AND " + + "user_id = ?"; + + PreparedStatement updateUsers = sqlCon.prepareStatement(update_users_QUERY); + PreparedStatement updateAppIdToUserId = sqlCon.prepareStatement(update_appid_to_userid_QUERY); + + int counter = 0; + for(Map.Entry linkEntry : recipeUserIdToPrimaryUserId.entrySet()) { + String primaryUserId = linkEntry.getValue(); + String recipeUserId = linkEntry.getKey(); + + updateUsers.setString(1, primaryUserId); + updateUsers.setString(2, appIdentifier.getAppId()); + updateUsers.setString(3, recipeUserId); + updateUsers.addBatch(); + + updateAppIdToUserId.setString(1, primaryUserId); + updateAppIdToUserId.setString(2, appIdentifier.getAppId()); + updateAppIdToUserId.setString(3, recipeUserId); + updateAppIdToUserId.addBatch(); + + counter++; + if (counter % 100 == 0) { + updateUsers.executeBatch(); + updateAppIdToUserId.executeBatch(); + } + } + + updateUsers.executeBatch(); + updateAppIdToUserId.executeBatch(); + + updateTimeJoinedForPrimaryUsers_Transaction(start, sqlCon, appIdentifier, + new ArrayList<>(recipeUserIdToPrimaryUserId.values())); + } + + public static void updateTimeJoinedForPrimaryUsers_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, List primaryUserIds) + throws SQLException, StorageQueryException { + String QUERY = "UPDATE " + getConfig(start).getUsersTable() + + " SET primary_or_recipe_user_time_joined = (SELECT MIN(time_joined) FROM " + + getConfig(start).getUsersTable() + " WHERE app_id = ? AND primary_or_recipe_user_id = ?) WHERE " + + " app_id = ? AND primary_or_recipe_user_id = ?"; + PreparedStatement updateStatement = sqlCon.prepareStatement(QUERY); + for(String primaryUserId : primaryUserIds) { + updateStatement.setString(1, appIdentifier.getAppId()); + updateStatement.setString(2, primaryUserId); + updateStatement.setString(3, appIdentifier.getAppId()); + updateStatement.setString(4, primaryUserId); + updateStatement.addBatch(); + } + + updateStatement.executeBatch(); + } + public static void unlinkAccounts_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String primaryUserId, String recipeUserId) throws SQLException, StorageQueryException { @@ -1343,6 +1498,39 @@ public static AuthRecipeUserInfo[] listPrimaryUsersByEmail_Transaction(Start sta return result.toArray(new AuthRecipeUserInfo[0]); } + public static AuthRecipeUserInfo[] listPrimaryUsersByMultipleEmailsOrPhonesOrThirdParty_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, + List emails, List phones, + Map thirdpartyUserIdToThirdpartyId) + throws SQLException, StorageQueryException { + Set userIds = new HashSet<>(); + + //I am not really sure this is really needed.. + EmailPasswordQueries.lockEmail_Transaction(start, sqlCon, appIdentifier, emails); + ThirdPartyQueries.lockEmail_Transaction(start, sqlCon, appIdentifier, emails); + PasswordlessQueries.lockEmail_Transaction(start, sqlCon, appIdentifier, emails); + PasswordlessQueries.lockPhoneAndTenant_Transaction(start, sqlCon, appIdentifier, phones); + ThirdPartyQueries.lockThirdPartyInfoAndTenant_Transaction(start, sqlCon, appIdentifier, thirdpartyUserIdToThirdpartyId); + + //collect ids by email + userIds.addAll(EmailPasswordQueries.getPrimaryUserIdsUsingMultipleEmails_Transaction(start, sqlCon, appIdentifier, + emails)); + userIds.addAll(PasswordlessQueries.getPrimaryUserIdsUsingMultipleEmails_Transaction(start, sqlCon, appIdentifier, + emails)); + userIds.addAll(ThirdPartyQueries.getPrimaryUserIdsUsingMultipleEmails_Transaction(start, sqlCon, appIdentifier, emails)); + + //collect ids by phone + userIds.addAll(PasswordlessQueries.listUserIdsByMultiplePhoneNumber_Transaction(start, sqlCon, appIdentifier, phones)); + + //collect ids by thirdparty + userIds.addAll(ThirdPartyQueries.listUserIdsByMultipleThirdPartyInfo_Transaction(start, sqlCon, appIdentifier, thirdpartyUserIdToThirdpartyId)); + + List result = getPrimaryUserInfoForUserIds_Transaction(start, sqlCon, appIdentifier, + new ArrayList<>(userIds)); + + return result.toArray(new AuthRecipeUserInfo[0]); + } + public static AuthRecipeUserInfo[] listPrimaryUsersByEmail(Start start, TenantIdentifier tenantIdentifier, String email) throws StorageQueryException, SQLException { @@ -1443,6 +1631,17 @@ public static AuthRecipeUserInfo getPrimaryUserInfoForUserId_Transaction(Start s return result.get(0); } + public static List getPrimaryUserInfosForUserIds_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, List ids) + throws SQLException, StorageQueryException { + + List result = getPrimaryUserInfoForUserIds_Transaction(start, con, appIdentifier, ids); + if (result.isEmpty()) { + return null; + } + return result; + } + private static List getPrimaryUserInfoForUserIds(Start start, AppIdentifier appIdentifier, List userIds) @@ -1553,16 +1752,18 @@ private static List getPrimaryUserInfoForUserIds_Transaction // column String QUERY = "SELECT au.user_id, au.primary_or_recipe_user_id, au.is_linked_or_is_a_primary_user, au.recipe_id, " + - "aaru.tenant_id, aaru.time_joined FROM " + - getConfig(start).getAppIdToUserIdTable() + " as au" + - " LEFT JOIN " + getConfig(start).getUsersTable() + - " as aaru ON au.app_id = aaru.app_id AND au.user_id = aaru.user_id" + - " WHERE au.primary_or_recipe_user_id IN (SELECT primary_or_recipe_user_id FROM " + - getConfig(start).getAppIdToUserIdTable() + " WHERE (user_id IN (" - + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) + - ") OR primary_or_recipe_user_id IN (" + - Utils.generateCommaSeperatedQuestionMarks(userIds.size()) + - ")) AND app_id = ?) AND au.app_id = ?"; + "aaru.tenant_id, aaru.time_joined " + + "FROM " + getConfig(start).getAppIdToUserIdTable() + " as au" + + " LEFT JOIN " + getConfig(start).getUsersTable() + + " as aaru ON au.app_id = aaru.app_id AND au.user_id = aaru.user_id" + + " WHERE au.primary_or_recipe_user_id IN " + + " (SELECT primary_or_recipe_user_id FROM " + + getConfig(start).getAppIdToUserIdTable() + + " WHERE (user_id IN (" + + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) +") " + + " OR primary_or_recipe_user_id IN (" + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) +")) " + + " AND app_id = ?) " + + "AND au.app_id = ?"; List allAuthUsersResult = execute(sqlCon, QUERY, pst -> { // IN user_id @@ -1576,7 +1777,8 @@ private static List getPrimaryUserInfoForUserIds_Transaction } // for app_id pst.setString(index, appIdentifier.getAppId()); - pst.setString(index + 1, appIdentifier.getAppId()); + pst.setString(index+1, appIdentifier.getAppId()); +// System.out.println(pst); }, result -> { List parsedResult = new ArrayList<>(); while (result.next()) { diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/PasswordlessQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/PasswordlessQueries.java index bfe1a2a9..6d3089cd 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/PasswordlessQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/PasswordlessQueries.java @@ -26,6 +26,7 @@ import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.passwordless.PasswordlessCode; import io.supertokens.pluginInterface.passwordless.PasswordlessDevice; +import io.supertokens.pluginInterface.passwordless.PasswordlessImportUser; import io.supertokens.pluginInterface.sqlStorage.SQLStorage.TransactionIsolationLevel; import io.supertokens.storage.postgresql.ConnectionPool; import io.supertokens.storage.postgresql.Start; @@ -35,6 +36,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; @@ -812,6 +814,30 @@ public static List lockEmail_Transaction(Start start, Connection con, Ap }); } + public static List lockEmail_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails == null || emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT user_id FROM " + getConfig(start).getPasswordlessUsersTable() + + " WHERE app_id = ? AND email IN (" + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + ") FOR UPDATE"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2 + i, emails.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(result.getString("user_id")); + } + return results; + }); + } + public static List lockPhoneAndTenant_Transaction(Start start, Connection con, AppIdentifier appIdentifier, String phoneNumber) @@ -831,6 +857,30 @@ public static List lockPhoneAndTenant_Transaction(Start start, Connectio }); } + public static List lockPhoneAndTenant_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List phones) + throws StorageQueryException, SQLException { + if(phones == null || phones.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT user_id FROM " + getConfig(start).getPasswordlessUsersTable() + + " WHERE app_id = ? AND phone_number IN (" + Utils.generateCommaSeperatedQuestionMarks(phones.size()) + ") FOR UPDATE"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < phones.size(); i++) { + pst.setString(2 + i, phones.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(result.getString("user_id")); + } + return results; + }); + } + public static String getPrimaryUserIdUsingEmail(Start start, TenantIdentifier tenantIdentifier, String email) throws StorageQueryException, SQLException { @@ -874,6 +924,33 @@ public static List getPrimaryUserIdsUsingEmail_Transaction(Start start, }); } + public static List getPrimaryUserIdsUsingMultipleEmails_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT DISTINCT all_users.primary_or_recipe_user_id AS user_id " + + "FROM " + getConfig(start).getPasswordlessUsersTable() + " AS ep" + + " JOIN " + getConfig(start).getAppIdToUserIdTable() + " AS all_users" + + " ON ep.app_id = all_users.app_id AND ep.user_id = all_users.user_id" + + " WHERE ep.app_id = ? AND ep.email IN ( " + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + " )"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2+i, emails.get(i)); + } + }, result -> { + List userIds = new ArrayList<>(); + while (result.next()) { + userIds.add(result.getString("user_id")); + } + return userIds; + }); + } + public static String getPrimaryUserByPhoneNumber(Start start, TenantIdentifier tenantIdentifier, @Nonnull String phoneNumber) throws StorageQueryException, SQLException { @@ -917,6 +994,33 @@ public static List listUserIdsByPhoneNumber_Transaction(Start start, Con }); } + public static List listUserIdsByMultiplePhoneNumber_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + @Nonnull List phoneNumbers) + throws StorageQueryException, SQLException { + if(phoneNumbers.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT DISTINCT all_users.primary_or_recipe_user_id AS user_id " + + "FROM " + getConfig(start).getPasswordlessUsersTable() + " AS pless" + + " JOIN " + getConfig(start).getUsersTable() + " AS all_users" + + " ON pless.app_id = all_users.app_id AND pless.user_id = all_users.user_id" + + " WHERE pless.app_id = ? AND pless.phone_number IN ( "+ Utils.generateCommaSeperatedQuestionMarks(phoneNumbers.size()) +" )"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < phoneNumbers.size(); i++) { + pst.setString(2 + i, phoneNumbers.get(i)); + } + }, result -> { + List userIds = new ArrayList<>(); + while (result.next()) { + userIds.add(result.getString("user_id")); + } + return userIds; + }); + } + public static boolean addUserIdToTenant_Transaction(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String userId) throws StorageQueryException, SQLException, UnknownUserIdException { @@ -1110,6 +1214,76 @@ private static List fillUserInfoWithTenantIds(Start start, return userInfos; } + public static void importUsers_Transaction(Connection sqlCon, Start start, + Collection users) throws SQLException { + + String app_id_to_user_id_QUERY = "INSERT INTO " + getConfig(start).getAppIdToUserIdTable() + + "(app_id, user_id, primary_or_recipe_user_id, recipe_id)" + " VALUES(?, ?, ?, ?)"; + PreparedStatement appIdToUserIdStatement = sqlCon.prepareStatement(app_id_to_user_id_QUERY); + + String all_auth_recipe_users_QUERY = "INSERT INTO " + getConfig(start).getUsersTable() + + "(app_id, tenant_id, user_id, primary_or_recipe_user_id, recipe_id, time_joined, " + + "primary_or_recipe_user_time_joined)" + + " VALUES(?, ?, ?, ?, ?, ?, ?)"; + PreparedStatement allAuthRecipeUsersStatement = sqlCon.prepareStatement(all_auth_recipe_users_QUERY); + + String passwordless_users_QUERY = "INSERT INTO " + getConfig(start).getPasswordlessUsersTable() + + "(app_id, user_id, email, phone_number, time_joined)" + " VALUES(?, ?, ?, ?, ?)"; + PreparedStatement passwordlessUsersStatement = sqlCon.prepareStatement(passwordless_users_QUERY); + + String passwordless_user_to_tenant_QUERY = "INSERT INTO " + getConfig(start).getPasswordlessUserToTenantTable() + + "(app_id, tenant_id, user_id, email, phone_number)" + " VALUES(?, ?, ?, ?, ?)"; + PreparedStatement passwordlessUserToTenantStatement = sqlCon.prepareStatement(passwordless_user_to_tenant_QUERY); + + int counter = 0; + for (PasswordlessImportUser user: users){ + TenantIdentifier tenantIdentifier = user.tenantIdentifier; + appIdToUserIdStatement.setString(1, tenantIdentifier.getAppId()); + appIdToUserIdStatement.setString(2, user.userId); + appIdToUserIdStatement.setString(3, user.userId); + appIdToUserIdStatement.setString(4, PASSWORDLESS.toString()); + appIdToUserIdStatement.addBatch(); + + allAuthRecipeUsersStatement.setString(1, tenantIdentifier.getAppId()); + allAuthRecipeUsersStatement.setString(2, tenantIdentifier.getTenantId()); + allAuthRecipeUsersStatement.setString(3, user.userId); + allAuthRecipeUsersStatement.setString(4, user.userId); + allAuthRecipeUsersStatement.setString(5, PASSWORDLESS.toString()); + allAuthRecipeUsersStatement.setLong(6, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsersStatement.setLong(7, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsersStatement.addBatch(); + + passwordlessUsersStatement.setString(1, tenantIdentifier.getAppId()); + passwordlessUsersStatement.setString(2, user.userId); + passwordlessUsersStatement.setString(3, user.email); + passwordlessUsersStatement.setString(4, user.phoneNumber); + passwordlessUsersStatement.setLong(5, user.timeJoinedMSSinceEpoch); + passwordlessUsersStatement.addBatch(); + + passwordlessUserToTenantStatement.setString(1, tenantIdentifier.getAppId()); + passwordlessUserToTenantStatement.setString(2, tenantIdentifier.getTenantId()); + passwordlessUserToTenantStatement.setString(3, user.userId); + passwordlessUserToTenantStatement.setString(4, user.email); + passwordlessUserToTenantStatement.setString(5, user.phoneNumber); + passwordlessUserToTenantStatement.addBatch(); + + counter++; + + if(counter % 100 == 0) { + appIdToUserIdStatement.executeBatch(); + allAuthRecipeUsersStatement.executeBatch(); + passwordlessUsersStatement.executeBatch(); + passwordlessUserToTenantStatement.executeBatch(); + } + } + + appIdToUserIdStatement.executeBatch(); + allAuthRecipeUsersStatement.executeBatch(); + passwordlessUsersStatement.executeBatch(); + passwordlessUserToTenantStatement.executeBatch(); + + } + private static class PasswordlessDeviceRowMapper implements RowMapper { private static final PasswordlessDeviceRowMapper INSTANCE = new PasswordlessDeviceRowMapper(); diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java index 0fe56e4d..c5108fd4 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/SessionQueries.java @@ -33,7 +33,9 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -306,6 +308,31 @@ public static String[] getAllNonExpiredSessionHandlesForUser(Start start, AppIde }); } + public static Map> getAllNonExpiredSessionHandlesForUsers(Start start, AppIdentifier appIdentifier, + List userIds) + throws SQLException, StorageQueryException { + String QUERY = "SELECT user_id, session_handle FROM " + getConfig(start).getSessionInfoTable() + + " WHERE app_id = ? AND expires_at >= ? AND user_id IN ( " + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) + " )"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, currentTimeMillis()); + for(int i = 0; i < userIds.size() ; i++){ + pst.setString(3 + i, userIds.get(i)); + } + }, result -> { + Map> temp = new HashMap<>(); + while (result.next()) { + String userId = result.getString("user_id"); + if(!temp.containsKey(userId)){ + temp.put(userId, new ArrayList<>()); + } + temp.get(userId).add(result.getString("session_handle")); + } + return temp; + }); + } + public static void deleteAllExpiredSessions(Start start) throws SQLException, StorageQueryException { String QUERY = "DELETE FROM " + getConfig(start).getSessionInfoTable() + " WHERE expires_at <= ?"; diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java index 60270a65..c14aa15d 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/TOTPQueries.java @@ -1,22 +1,24 @@ package io.supertokens.storage.postgresql.queries; -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import io.supertokens.pluginInterface.multitenancy.AppIdentifier; -import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; -import io.supertokens.storage.postgresql.Start; -import io.supertokens.storage.postgresql.config.Config; import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.exceptions.StorageQueryException; -import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; import io.supertokens.pluginInterface.totp.TOTPDevice; import io.supertokens.pluginInterface.totp.TOTPUsedCode; +import io.supertokens.storage.postgresql.Start; +import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -146,6 +148,52 @@ public static void createDevice_Transaction(Start start, Connection sqlCon, AppI insertDevice_Transaction(start, sqlCon, appIdentifier, device); } + public static void createDevices_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, + List devices) + throws SQLException, StorageQueryException { + + String insert_user_QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUsersTable() + + " (app_id, user_id) VALUES (?, ?) ON CONFLICT DO NOTHING"; + + String insert_device_QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() + + + " (app_id, user_id, device_name, secret_key, period, skew, verified, created_at) VALUES (?, ?, ?, ?, " + + "?, ?, ?, ?) ON CONFLICT (app_id, user_id, device_name) DO UPDATE SET secret_key = ?, period = ?, skew = ?, created_at = ?, verified = ?"; + + PreparedStatement insertUserStatement = sqlCon.prepareStatement(insert_user_QUERY); + PreparedStatement insertDeviceStatement = sqlCon.prepareStatement(insert_device_QUERY); + + int counter = 0; + for(TOTPDevice device : devices){ + insertUserStatement.setString(1, appIdentifier.getAppId()); + insertUserStatement.setString(2, device.userId); + insertUserStatement.addBatch(); + + insertDeviceStatement.setString(1, appIdentifier.getAppId()); + insertDeviceStatement.setString(2, device.userId); + insertDeviceStatement.setString(3, device.deviceName); + insertDeviceStatement.setString(4, device.secretKey); + insertDeviceStatement.setInt(5, device.period); + insertDeviceStatement.setInt(6, device.skew); + insertDeviceStatement.setBoolean(7, device.verified); + insertDeviceStatement.setLong(8, device.createdAt); + insertDeviceStatement.setString(9, device.secretKey); + insertDeviceStatement.setInt(10, device.period); + insertDeviceStatement.setInt(11, device.skew); + insertDeviceStatement.setLong(12, device.createdAt); + insertDeviceStatement.setBoolean(13, device.verified); + insertDeviceStatement.addBatch(); + counter++; + if(counter % 100 == 0) { + insertUserStatement.executeBatch(); + insertDeviceStatement.executeBatch(); + } + } + + insertUserStatement.executeBatch(); + insertDeviceStatement.executeBatch(); + } + public static TOTPDevice getDeviceByName_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String userId, String deviceName) throws SQLException, StorageQueryException { @@ -245,6 +293,30 @@ public static TOTPDevice[] getDevices(Start start, AppIdentifier appIdentifier, }); } + public static Map> getDevicesForMultipleUsers(Start start, AppIdentifier appIdentifier, List userIds) + throws StorageQueryException, SQLException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getTotpUserDevicesTable() + + " WHERE app_id = ? AND user_id IN (" + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) + ");"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for(int i = 0; i < userIds.size(); i++) { + pst.setString(2+i, userIds.get(i)); + } + }, result -> { + Map> devicesByUserIds = new HashMap<>(); + while (result.next()) { + String userId = result.getString("user_id"); + if (!devicesByUserIds.containsKey(userId)){ + devicesByUserIds.put(userId, new ArrayList<>()); + } + devicesByUserIds.get(userId).add(TOTPDeviceRowMapper.getInstance().map(result)); + } + + return devicesByUserIds; + }); + } + public static TOTPDevice[] getDevices_Transaction(Start start, Connection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException, SQLException { diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/ThirdPartyQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/ThirdPartyQueries.java index 964b53cd..afe63b55 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/ThirdPartyQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/ThirdPartyQueries.java @@ -19,19 +19,19 @@ import io.supertokens.pluginInterface.RowMapper; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.authRecipe.LoginMethod; -import io.supertokens.pluginInterface.emailpassword.exceptions.DuplicateEmailException; import io.supertokens.pluginInterface.emailpassword.exceptions.UnknownUserIdException; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; -import io.supertokens.pluginInterface.thirdparty.exception.DuplicateThirdPartyUserException; +import io.supertokens.pluginInterface.thirdparty.ThirdPartyImportUser; import io.supertokens.storage.postgresql.ConnectionPool; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; @@ -222,6 +222,30 @@ public static List lockEmail_Transaction(Start start, Connection con, }); } + public static List lockEmail_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails == null || emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT user_id FROM " + getConfig(start).getThirdPartyUsersTable() + + " WHERE app_id = ? AND email IN (" + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + ") FOR UPDATE"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2 + i, emails.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(result.getString("user_id")); + } + return results; + }); + } + public static List lockThirdPartyInfoAndTenant_Transaction(Start start, Connection con, AppIdentifier appIdentifier, String thirdPartyId, String thirdPartyUserId) @@ -243,6 +267,38 @@ public static List lockThirdPartyInfoAndTenant_Transaction(Start start, }); } + public static List lockThirdPartyInfoAndTenant_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + Map thirdPartyUserIdToThirdPartyId) + throws SQLException, StorageQueryException { + if(thirdPartyUserIdToThirdPartyId == null || thirdPartyUserIdToThirdPartyId.isEmpty()) { + return new ArrayList<>(); + } + + String QUERY = "SELECT user_id " + + " FROM " + getConfig(start).getThirdPartyUsersTable() + + " WHERE app_id = ? AND third_party_id IN ("+Utils.generateCommaSeperatedQuestionMarks( + thirdPartyUserIdToThirdPartyId.size())+") AND third_party_user_id IN ("+ + Utils.generateCommaSeperatedQuestionMarks(thirdPartyUserIdToThirdPartyId.size())+") FOR UPDATE"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + int counter = 2; + for (String thirdPartyId : thirdPartyUserIdToThirdPartyId.values()){ + pst.setString(counter++, thirdPartyId); + } + for (String thirdPartyUserId : thirdPartyUserIdToThirdPartyId.keySet()) { + pst.setString(counter++, thirdPartyUserId); + } + }, result -> { + List finalResult = new ArrayList<>(); + while (result.next()) { + finalResult.add(result.getString("user_id")); + } + return finalResult; + }); + } + public static List getUsersInfoUsingIdList(Start start, Set ids, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { @@ -353,6 +409,41 @@ public static List listUserIdsByThirdPartyInfo_Transaction(Start start, }); } + public static List listUserIdsByMultipleThirdPartyInfo_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + Map thirdPartyUserIdToThirdPartyId) + throws SQLException, StorageQueryException { + if(thirdPartyUserIdToThirdPartyId.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT DISTINCT all_users.primary_or_recipe_user_id AS user_id " + + "FROM " + getConfig(start).getThirdPartyUsersTable() + " AS tp" + + " JOIN " + getConfig(start).getUsersTable() + " AS all_users" + + " ON tp.app_id = all_users.app_id AND tp.user_id = all_users.user_id" + + " WHERE tp.app_id = ? AND tp.third_party_id IN ( " + Utils.generateCommaSeperatedQuestionMarks( + thirdPartyUserIdToThirdPartyId.size()) + " ) AND tp.third_party_user_id IN ( " + Utils.generateCommaSeperatedQuestionMarks( + thirdPartyUserIdToThirdPartyId.size()) + " )"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + int counter = 2; + for (String thirdpartId : thirdPartyUserIdToThirdPartyId.values()){ + pst.setString(counter, thirdpartId); + counter++; + } + for (String thirdparyUserId : thirdPartyUserIdToThirdPartyId.keySet()){ + pst.setString(counter, thirdparyUserId); + counter++; + } + }, result -> { + List userIds = new ArrayList<>(); + while (result.next()) { + userIds.add(result.getString("user_id")); + } + return userIds; + }); + } + public static String getUserIdByThirdPartyInfo(Start start, TenantIdentifier tenantIdentifier, String thirdPartyId, String thirdPartyUserId) throws SQLException, StorageQueryException { @@ -455,6 +546,33 @@ public static List getPrimaryUserIdUsingEmail_Transaction(Start start, C }); } + public static List getPrimaryUserIdsUsingMultipleEmails_Transaction(Start start, Connection con, + AppIdentifier appIdentifier, + List emails) + throws StorageQueryException, SQLException { + if(emails.isEmpty()){ + return new ArrayList<>(); + } + String QUERY = "SELECT DISTINCT all_users.primary_or_recipe_user_id AS user_id " + + "FROM " + getConfig(start).getThirdPartyUsersTable() + " AS ep" + + " JOIN " + getConfig(start).getAppIdToUserIdTable() + " AS all_users" + + " ON ep.app_id = all_users.app_id AND ep.user_id = all_users.user_id" + + " WHERE ep.app_id = ? AND ep.email IN ( " + Utils.generateCommaSeperatedQuestionMarks(emails.size()) + " )"; + + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < emails.size(); i++) { + pst.setString(2+i, emails.get(i)); + } + }, result -> { + List userIds = new ArrayList<>(); + while (result.next()) { + userIds.add(result.getString("user_id")); + } + return userIds; + }); + } + public static boolean addUserIdToTenant_Transaction(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String userId) throws SQLException, StorageQueryException, UnknownUserIdException { @@ -527,6 +645,80 @@ public static boolean removeUserIdFromTenant_Transaction(Start start, Connection // automatically deleted from thirdparty_user_to_tenant because of foreign key constraint } + public static void importUser_Transaction(Start start, Connection sqlConnection, Collection users) + throws SQLException { + + String app_id_userid_QUERY = "INSERT INTO " + getConfig(start).getAppIdToUserIdTable() + + "(app_id, user_id, primary_or_recipe_user_id, recipe_id)" + " VALUES(?, ?, ?, ?)"; + + String all_auth_recipe_users_QUERY = "INSERT INTO " + getConfig(start).getUsersTable() + + + "(app_id, tenant_id, user_id, primary_or_recipe_user_id, recipe_id, time_joined, " + + "primary_or_recipe_user_time_joined)" + + " VALUES(?, ?, ?, ?, ?, ?, ?)"; + + String thirdparty_users_QUERY = "INSERT INTO " + getConfig(start).getThirdPartyUsersTable() + + "(app_id, third_party_id, third_party_user_id, user_id, email, time_joined)" + + " VALUES(?, ?, ?, ?, ?, ?)"; + + String thirdparty_user_to_tenant_QUERY = "INSERT INTO " + getConfig(start).getThirdPartyUserToTenantTable() + + "(app_id, tenant_id, user_id, third_party_id, third_party_user_id)" + + " VALUES(?, ?, ?, ?, ?)"; + + PreparedStatement appIdToUserIdStatement = sqlConnection.prepareStatement(app_id_userid_QUERY); + PreparedStatement allAuthRecipeUsersStatement = sqlConnection.prepareStatement(all_auth_recipe_users_QUERY); + PreparedStatement thirdPartyUsersStatement = sqlConnection.prepareStatement(thirdparty_users_QUERY); + PreparedStatement thirdPartyUsersToTenantStatement = sqlConnection.prepareStatement( + thirdparty_user_to_tenant_QUERY); + + int counter = 0; + for (ThirdPartyImportUser user : users) { + TenantIdentifier tenantIdentifier = user.tenantIdentifier; + appIdToUserIdStatement.setString(1, tenantIdentifier.getAppId()); + appIdToUserIdStatement.setString(2, user.userId); + appIdToUserIdStatement.setString(3, user.userId); + appIdToUserIdStatement.setString(4, THIRD_PARTY.toString()); + appIdToUserIdStatement.addBatch(); + + allAuthRecipeUsersStatement.setString(1, tenantIdentifier.getAppId()); + allAuthRecipeUsersStatement.setString(2, tenantIdentifier.getTenantId()); + allAuthRecipeUsersStatement.setString(3, user.userId); + allAuthRecipeUsersStatement.setString(4, user.userId); + allAuthRecipeUsersStatement.setString(5, THIRD_PARTY.toString()); + allAuthRecipeUsersStatement.setLong(6, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsersStatement.setLong(7, user.timeJoinedMSSinceEpoch); + allAuthRecipeUsersStatement.addBatch(); + + thirdPartyUsersStatement.setString(1, tenantIdentifier.getAppId()); + thirdPartyUsersStatement.setString(2, user.thirdpartyId); + thirdPartyUsersStatement.setString(3, user.thirdpartyUserId); + thirdPartyUsersStatement.setString(4, user.userId); + thirdPartyUsersStatement.setString(5, user.email); + thirdPartyUsersStatement.setLong(6, user.timeJoinedMSSinceEpoch); + thirdPartyUsersStatement.addBatch(); + + thirdPartyUsersToTenantStatement.setString(1, tenantIdentifier.getAppId()); + thirdPartyUsersToTenantStatement.setString(2, tenantIdentifier.getTenantId()); + thirdPartyUsersToTenantStatement.setString(3, user.userId); + thirdPartyUsersToTenantStatement.setString(4, user.thirdpartyId); + thirdPartyUsersToTenantStatement.setString(5, user.thirdpartyUserId); + thirdPartyUsersToTenantStatement.addBatch(); + + counter++; + if(counter % 100 == 0) { + appIdToUserIdStatement.executeBatch(); + allAuthRecipeUsersStatement.executeBatch(); + thirdPartyUsersStatement.executeBatch(); + thirdPartyUsersToTenantStatement.executeBatch(); + } + } + + appIdToUserIdStatement.executeBatch(); + allAuthRecipeUsersStatement.executeBatch(); + thirdPartyUsersStatement.executeBatch(); + thirdPartyUsersToTenantStatement.executeBatch(); + } + private static UserInfoPartial fillUserInfoWithVerified_transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, UserInfoPartial userInfo) diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java index 072571c1..48ea4dd1 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/UserIdMappingQueries.java @@ -20,17 +20,20 @@ import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.useridmapping.UserIdMapping; +import io.supertokens.storage.postgresql.ConnectionPool; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; import javax.annotation.Nullable; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -80,6 +83,30 @@ public static void createUserIdMapping(Start start, AppIdentifier appIdentifier, }); } + public static void createBulkUserIdMapping(Start start, AppIdentifier appIdentifier, + Map superTokensUserIdToExternalUserId) + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getUserIdMappingTable() + + " (app_id, supertokens_user_id, external_user_id)" + " VALUES(?, ?, ?)"; + + Connection sqlConnection = ConnectionPool.getConnection(start); + PreparedStatement insertStatement = sqlConnection.prepareStatement(QUERY); + + int counter = 0; + for(String superTokensUserId : superTokensUserIdToExternalUserId.keySet()) { + insertStatement.setString(1, appIdentifier.getAppId()); + insertStatement.setString(2, superTokensUserId); + insertStatement.setString(3, superTokensUserIdToExternalUserId.get(superTokensUserId)); + insertStatement.addBatch(); + + counter++; + if(counter % 100 == 0) { + insertStatement.executeBatch(); + } + } + insertStatement.executeBatch(); + } + public static UserIdMapping getuseraIdMappingWithSuperTokensUserId(Start start, AppIdentifier appIdentifier, String userId) throws SQLException, StorageQueryException { @@ -304,6 +331,50 @@ public static UserIdMapping getUserIdMappingWithExternalUserId_Transaction(Start }); } + public static List getMultipleUserIdMappingWithExternalUserId_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, + List userId) + throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + + " WHERE app_id = ? AND external_user_id IN ( "+ Utils.generateCommaSeperatedQuestionMarks( + userId.size()) + " )"; + + return execute(sqlCon, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for(int i = 0; i < userId.size(); i++) { + pst.setString(2 + i, userId.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(UserIdMappingRowMapper.getInstance().mapOrThrow(result)); + } + return results; + }); + } + + public static List getMultipleUserIdMappingWithSupertokensUserId_Transaction(Start start, Connection sqlCon, + AppIdentifier appIdentifier, + List userId) + throws SQLException, StorageQueryException { + String QUERY = "SELECT * FROM " + Config.getConfig(start).getUserIdMappingTable() + + " WHERE app_id = ? AND supertokens_user_id IN ( "+ Utils.generateCommaSeperatedQuestionMarks( + userId.size()) + " )"; + + return execute(sqlCon, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for(int i = 0; i < userId.size(); i++) { + pst.setString(2 + i, userId.get(i)); + } + }, result -> { + List results = new ArrayList<>(); + while (result.next()) { + results.add(UserIdMappingRowMapper.getInstance().mapOrThrow(result)); + } + return results; + }); + } + public static UserIdMapping[] getUserIdMappingWithEitherSuperTokensUserIdOrExternalUserId_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/UserMetadataQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/UserMetadataQueries.java index 1d2b6231..d4fde3d7 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/UserMetadataQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/UserMetadataQueries.java @@ -19,13 +19,18 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -93,6 +98,31 @@ public static int setUserMetadata_Transaction(Start start, Connection con, AppId }); } + public static void setMultipleUsersMetadatas_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + Map metadatasByUserId) + throws SQLException, StorageQueryException { + + String QUERY = "INSERT INTO " + getConfig(start).getUserMetadataTable() + + "(app_id, user_id, user_metadata) VALUES(?, ?, ?) " + + "ON CONFLICT(app_id, user_id) DO UPDATE SET user_metadata=excluded.user_metadata;"; + PreparedStatement insertStatement = con.prepareStatement(QUERY); + + int counter = 0; + for(Map.Entry metadataByUserId : metadatasByUserId.entrySet()){ + insertStatement.setString(1, appIdentifier.getAppId()); + insertStatement.setString(2, metadataByUserId.getKey()); + insertStatement.setString(3, metadataByUserId.getValue().toString()); + insertStatement.addBatch(); + + counter++; + if(counter % 100 == 0) { + insertStatement.executeBatch(); + } + } + + insertStatement.executeBatch(); + } + public static JsonObject getUserMetadata_Transaction(Start start, Connection con, AppIdentifier appIdentifier, String userId) throws SQLException, StorageQueryException { @@ -110,6 +140,28 @@ public static JsonObject getUserMetadata_Transaction(Start start, Connection con }); } + public static Map getMultipleUsersMetadatas_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + List userIds) + throws SQLException, StorageQueryException { + String QUERY = "SELECT user_id, user_metadata FROM " + getConfig(start).getUserMetadataTable() + + " WHERE app_id = ? AND user_id IN (" + Utils.generateCommaSeperatedQuestionMarks(userIds.size()) + + ") FOR UPDATE"; + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i< userIds.size(); i++){ + pst.setString(2+i, userIds.get(i)); + } + }, result -> { + Map userMetadataByUserId = new HashMap<>(); + JsonParser jp = new JsonParser(); + if (result.next()) { + userMetadataByUserId.put(result.getString("user_id"), + jp.parse(result.getString("user_metadata")).getAsJsonObject()); + } + return userMetadataByUserId; + }); + } + public static JsonObject getUserMetadata(Start start, AppIdentifier appIdentifier, String userId) throws SQLException, StorageQueryException { String QUERY = "SELECT user_metadata FROM " + getConfig(start).getUserMetadataTable() @@ -125,4 +177,11 @@ public static JsonObject getUserMetadata(Start start, AppIdentifier appIdentifie return null; }); } + + public static Map getMultipleUserMetadatas(Start start, AppIdentifier appIdentifier, List userIds) + throws StorageQueryException, StorageTransactionLogicException { + return start.startTransaction(con -> { + return getMultipleUsersMetadatas_Transaction(start, (Connection) con.getConnection(), appIdentifier, userIds); + }); + } } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/UserRolesQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/UserRolesQueries.java index 5825164c..807c8bf3 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/UserRolesQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/UserRolesQueries.java @@ -17,18 +17,20 @@ package io.supertokens.storage.postgresql.queries; import io.supertokens.pluginInterface.exceptions.StorageQueryException; -import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; -import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.config.Config; import io.supertokens.storage.postgresql.utils.Utils; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; @@ -202,6 +204,33 @@ public static int addRoleToUser(Start start, TenantIdentifier tenantIdentifier, }); } + public static void addRolesToUsers_Transaction(Start start, Connection connection, Map>> rolesToUserByTenants) //tenant -> user -> role + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + getConfig(start).getUserRolesTable() + + "(app_id, tenant_id, user_id, role) VALUES(?, ?, ?, ?);"; + PreparedStatement insertStatement = connection.prepareStatement(QUERY); + + int counter = 0; + for(Map.Entry>> tenantsEntry : rolesToUserByTenants.entrySet()) { + for(Map.Entry> rolesToUser : tenantsEntry.getValue().entrySet()) { + for(String roleForUser : rolesToUser.getValue()){ + insertStatement.setString(1, tenantsEntry.getKey().getAppId()); + insertStatement.setString(2, tenantsEntry.getKey().getTenantId()); + insertStatement.setString(3, rolesToUser.getKey()); + insertStatement.setString(4, roleForUser); + insertStatement.addBatch(); + counter++; + + if(counter % 100 == 0) { + insertStatement.executeBatch(); + } + } + } + } + + insertStatement.executeBatch(); + } + public static String[] getRolesForUser(Start start, TenantIdentifier tenantIdentifier, String userId) throws SQLException, StorageQueryException { String QUERY = "SELECT role FROM " + getConfig(start).getUserRolesTable() @@ -237,6 +266,29 @@ public static String[] getRolesForUser(Start start, AppIdentifier appIdentifier, }); } + public static Map> getRolesForUsers(Start start, AppIdentifier appIdentifier, List userIds) + throws SQLException, StorageQueryException { + String QUERY = "SELECT user_id, role FROM " + getConfig(start).getUserRolesTable() + + " WHERE app_id = ? AND user_id IN ("+Utils.generateCommaSeperatedQuestionMarks(userIds.size())+") ;"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for(int i = 0; i < userIds.size(); i++) { + pst.setString(2+i, userIds.get(i)); + } + }, result -> { + Map> rolesByUserId = new HashMap<>(); + while (result.next()) { + String userId = result.getString("user_id"); + if(!rolesByUserId.containsKey(userId)) { + rolesByUserId.put(userId, new ArrayList<>()); + } + rolesByUserId.get(userId).add(result.getString("role")); + } + return rolesByUserId; + }); + } + public static boolean deleteRoleForUser_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, String userId, String role) throws SQLException, StorageQueryException { @@ -264,6 +316,25 @@ public static boolean doesRoleExist_transaction(Start start, Connection con, App }, ResultSet::next); } + public static List doesMultipleRoleExist_transaction(Start start, Connection con, AppIdentifier appIdentifier, + List roles) + throws SQLException, StorageQueryException { + String QUERY = "SELECT role FROM " + getConfig(start).getRolesTable() + + " WHERE app_id = ? AND role IN (" +Utils.generateCommaSeperatedQuestionMarks(roles.size())+ ") FOR UPDATE"; + return execute(con, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + for (int i = 0; i < roles.size(); i++) { + pst.setString(2+i, roles.get(i)); + } + }, result -> { + List rolesFound = new ArrayList<>(); + while(result.next()){ + rolesFound.add(result.getString("role")); + } + return rolesFound; + }); + } + public static String[] getUsersForRole(Start start, TenantIdentifier tenantIdentifier, String role) throws SQLException, StorageQueryException { String QUERY = "SELECT user_id FROM " + getConfig(start).getUserRolesTable() diff --git a/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java b/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java index ea29a8f9..28799f22 100644 --- a/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java +++ b/src/test/java/io/supertokens/storage/postgresql/test/OneMillionUsersTest.java @@ -19,11 +19,15 @@ import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import io.supertokens.ActiveUsers; import io.supertokens.Main; import io.supertokens.ProcessState; import io.supertokens.authRecipe.AuthRecipe; import io.supertokens.authRecipe.UserPaginationContainer; +import io.supertokens.cronjobs.CronTaskTest; +import io.supertokens.cronjobs.Cronjobs; +import io.supertokens.cronjobs.bulkimport.ProcessBulkImportUsers; import io.supertokens.emailpassword.EmailPassword; import io.supertokens.emailpassword.ParsedFirebaseSCryptResponse; import io.supertokens.featureflag.EE_FEATURES; @@ -31,19 +35,23 @@ import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.passwordless.Passwordless; import io.supertokens.pluginInterface.RECIPE_ID; +import io.supertokens.pluginInterface.STORAGE_TYPE; import io.supertokens.pluginInterface.Storage; import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.authRecipe.LoginMethod; import io.supertokens.pluginInterface.authRecipe.sqlStorage.AuthRecipeSQLStorage; +import io.supertokens.pluginInterface.bulkimport.BulkImportStorage; import io.supertokens.pluginInterface.emailpassword.sqlStorage.EmailPasswordSQLStorage; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.passwordless.sqlStorage.PasswordlessSQLStorage; import io.supertokens.pluginInterface.thirdparty.sqlStorage.ThirdPartySQLStorage; import io.supertokens.session.Session; import io.supertokens.session.info.SessionInformationHolder; import io.supertokens.storage.postgresql.Start; import io.supertokens.storage.postgresql.test.httpRequest.HttpRequestForTesting; +import io.supertokens.storage.postgresql.test.httpRequest.HttpResponseException; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.thirdparty.ThirdParty; import io.supertokens.useridmapping.UserIdMapping; @@ -56,8 +64,12 @@ import org.junit.Test; import org.junit.rules.TestRule; +import java.io.IOException; +import java.net.SocketTimeoutException; import java.util.*; -import java.util.concurrent.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; @@ -969,4 +981,239 @@ private static long measureTime(Supplier function) { // Calculate elapsed time in milliseconds return (endTime - startTime) / 1000000; // Convert to milliseconds } + + @Test + public void testWithOneMillionUsers() throws Exception { + Main main = startCronProcess(String.valueOf(NUM_THREADS)); + + int NUMBER_OF_USERS_TO_UPLOAD = 1000000; // million + + if (StorageLayer.getBaseStorage(main).getType() != STORAGE_TYPE.SQL || StorageLayer.isInMemDb(main)) { + return; + } + + // Create user roles before inserting bulk users + { + UserRoles.createNewRoleOrModifyItsPermissions(main, "role1", null); + UserRoles.createNewRoleOrModifyItsPermissions(main, "role2", null); + } + + // upload a bunch of users through the API + { + for (int i = 0; i < (NUMBER_OF_USERS_TO_UPLOAD / 10000); i++) { + JsonObject request = generateUsersJson(10000, i * 10000); // API allows 10k users upload at once + JsonObject response = uploadBulkImportUsersJson(main, request); + assertEquals("OK", response.get("status").getAsString()); + } + + } + + long processingStarted = System.currentTimeMillis(); + + // wait for the cron job to process them + // periodically check the remaining unprocessed users + // Note1: the cronjob starts the processing automatically + // Note2: the successfully processed users get deleted from the bulk_import_users table + { + long count = NUMBER_OF_USERS_TO_UPLOAD; + while(true) { + try { + JsonObject response = loadBulkImportUsersCountWithStatus(main, null); + assertEquals("OK", response.get("status").getAsString()); + count = response.get("count").getAsLong(); + int newUsersNumber = loadBulkImportUsersCountWithStatus(main, + BulkImportStorage.BULK_IMPORT_USER_STATUS.NEW).get("count").getAsInt(); + int processingUsersNumber = loadBulkImportUsersCountWithStatus(main, + BulkImportStorage.BULK_IMPORT_USER_STATUS.PROCESSING).get("count").getAsInt(); + int failedUsersNumber = loadBulkImportUsersCountWithStatus(main, + BulkImportStorage.BULK_IMPORT_USER_STATUS.FAILED).get("count").getAsInt(); + count = newUsersNumber + processingUsersNumber; + + if (count == 0) { + break; + } + } catch (Exception e) { + if(e instanceof SocketTimeoutException) { + //ignore + } else { + throw e; + } + } + Thread.sleep(5000); + } + } + + long processingFinished = System.currentTimeMillis(); + System.out.println("Processed " + NUMBER_OF_USERS_TO_UPLOAD + " users in " + (processingFinished - processingStarted) / 1000 + + " seconds ( or " + (processingFinished - processingStarted) / 60000 + " minutes)"); + + // after processing finished, make sure every user got processed correctly + { + int failedImportedUsersNumber = loadBulkImportUsersCountWithStatus(main, BulkImportStorage.BULK_IMPORT_USER_STATUS.FAILED).get("count").getAsInt(); + int usersInCore = loadUsersCount(main).get("count").getAsInt(); + assertEquals(NUMBER_OF_USERS_TO_UPLOAD, usersInCore + failedImportedUsersNumber); + assertEquals(NUMBER_OF_USERS_TO_UPLOAD, usersInCore); + } + } + + private static JsonObject loadBulkImportUsersCountWithStatus(Main main, BulkImportStorage.BULK_IMPORT_USER_STATUS status) + throws HttpResponseException, IOException { + Map params = new HashMap<>(); + if(status!= null) { + params.put("status", status.name()); + } + return HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/bulk-import/users/count", + params, 10000, 10000, null, SemVer.v5_2.get(), null); + } + + private static JsonObject loadBulkImportUsersWithStatus(Main main, BulkImportStorage.BULK_IMPORT_USER_STATUS status) + throws HttpResponseException, IOException { + Map params = new HashMap<>(); + if(status!= null) { + params.put("status", status.name()); + } + return HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/bulk-import/users", + params, 10000, 10000, null, SemVer.v5_2.get(), null); + } + + private static JsonObject loadUsersCount(Main main) throws HttpResponseException, IOException { + Map params = new HashMap<>(); + + return HttpRequestForTesting.sendGETRequest(main, "", + "http://localhost:3567/users/count", + params, 10000, 10000, null, SemVer.v5_2.get(), null); + } + + private static JsonObject generateUsersJson(int numberOfUsers, int startIndex) { + JsonObject userJsonObject = new JsonObject(); + JsonParser parser = new JsonParser(); + + JsonArray usersArray = new JsonArray(); + for (int i = 0; i < numberOfUsers; i++) { + JsonObject user = new JsonObject(); + + user.addProperty("externalUserId", UUID.randomUUID().toString()); + user.add("userMetadata", parser.parse("{\"key1\":"+ UUID.randomUUID().toString() + ",\"key2\":{\"key3\":\"value3\"}}")); + user.add("userRoles", parser.parse( + "[{\"role\":\"role1\", \"tenantIds\": [\"public\"]},{\"role\":\"role2\", \"tenantIds\": [\"public\"]}]")); + user.add("totpDevices", parser.parse("[{\"secretKey\":\"secretKey\",\"deviceName\":\"deviceName\"}]")); + + //JsonArray tenanatIds = parser.parse("[\"public\", \"t1\"]").getAsJsonArray(); + JsonArray tenanatIds = parser.parse("[\"public\"]").getAsJsonArray(); + String email = " johndoe+" + (i + startIndex) + "@gmail.com "; + + Random random = new Random(); + + JsonArray loginMethodsArray = new JsonArray(); + //if(random.nextInt(2) == 0){ + loginMethodsArray.add(createEmailLoginMethod(email, tenanatIds)); + //} + if(random.nextInt(2) == 0){ + loginMethodsArray.add(createThirdPartyLoginMethod(email, tenanatIds)); + } + if(random.nextInt(2) == 0){ + loginMethodsArray.add(createPasswordlessLoginMethod(email, tenanatIds, "+910000" + (startIndex + i))); + } + if(loginMethodsArray.size() == 0) { + int methodNumber = random.nextInt(3); + switch (methodNumber) { + case 0: + loginMethodsArray.add(createEmailLoginMethod(email, tenanatIds)); + break; + case 1: + loginMethodsArray.add(createThirdPartyLoginMethod(email, tenanatIds)); + break; + case 2: + loginMethodsArray.add(createPasswordlessLoginMethod(email, tenanatIds, "+911000" + (startIndex + i))); + break; + } + } + user.add("loginMethods", loginMethodsArray); + + usersArray.add(user); + } + + userJsonObject.add("users", usersArray); + return userJsonObject; + } + + private static JsonObject createEmailLoginMethod(String email, JsonArray tenantIds) { + JsonObject loginMethod = new JsonObject(); + loginMethod.add("tenantIds", tenantIds); + loginMethod.addProperty("email", email); + loginMethod.addProperty("recipeId", "emailpassword"); + loginMethod.addProperty("passwordHash", + "$argon2d$v=19$m=12,t=3,p=1$aGI4enNvMmd0Zm0wMDAwMA$r6p7qbr6HD+8CD7sBi4HVw"); + loginMethod.addProperty("hashingAlgorithm", "argon2"); + loginMethod.addProperty("isVerified", true); + loginMethod.addProperty("isPrimary", true); + loginMethod.addProperty("timeJoinedInMSSinceEpoch", 0); + return loginMethod; + } + + private static JsonObject createThirdPartyLoginMethod(String email, JsonArray tenantIds) { + JsonObject loginMethod = new JsonObject(); + loginMethod.add("tenantIds", tenantIds); + loginMethod.addProperty("recipeId", "thirdparty"); + loginMethod.addProperty("email", email); + loginMethod.addProperty("thirdPartyId", "google"); + loginMethod.addProperty("thirdPartyUserId", String.valueOf(email.hashCode())); + loginMethod.addProperty("isVerified", true); + loginMethod.addProperty("isPrimary", false); + loginMethod.addProperty("timeJoinedInMSSinceEpoch", 0); + return loginMethod; + } + + private static JsonObject createPasswordlessLoginMethod(String email, JsonArray tenantIds, String phoneNumber) { + JsonObject loginMethod = new JsonObject(); + loginMethod.add("tenantIds", tenantIds); + loginMethod.addProperty("email", email); + loginMethod.addProperty("recipeId", "passwordless"); + loginMethod.addProperty("phoneNumber", phoneNumber); + loginMethod.addProperty("isVerified", true); + loginMethod.addProperty("isPrimary", false); + loginMethod.addProperty("timeJoinedInMSSinceEpoch", 0); + return loginMethod; + } + + private void setFeatureFlags(Main main, EE_FEATURES[] features) { + FeatureFlagTestContent.getInstance(main).setKeyValue(FeatureFlagTestContent.ENABLED_FEATURES, features); + } + + + private static JsonObject uploadBulkImportUsersJson(Main main, JsonObject request) throws IOException, + HttpResponseException { + return HttpRequestForTesting.sendJsonPOSTRequest(main, "", + "http://localhost:3567/bulk-import/users", + request, 1000, 10000, null, SemVer.v5_2.get(), null); + } + + private Main startCronProcess(String parallelism) throws IOException, InterruptedException, + TenantOrAppNotFoundException { + return startCronProcess(parallelism, 5*60); + } + + private Main startCronProcess(String parallelism, int intervalInSeconds) throws IOException, InterruptedException, TenantOrAppNotFoundException { + String[] args = { "../" }; + + // set processing thread number + Utils.setValueInConfig("bulk_migration_parallelism", parallelism); + + TestingProcessManager.TestingProcess process = TestingProcessManager.start(args, false); + Main main = process.getProcess(); + setFeatureFlags(main, new EE_FEATURES[] { + EE_FEATURES.ACCOUNT_LINKING, EE_FEATURES.MULTI_TENANCY, EE_FEATURES.MFA }); + // We are setting a non-zero initial wait for tests to avoid race condition with the beforeTest process that deletes data in the storage layer + CronTaskTest.getInstance(main).setInitialWaitTimeInSeconds(ProcessBulkImportUsers.RESOURCE_KEY, 5); + CronTaskTest.getInstance(main).setIntervalInSeconds(ProcessBulkImportUsers.RESOURCE_KEY, intervalInSeconds); + + process.startProcess(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STARTED)); + + Cronjobs.addCronjob(main, (ProcessBulkImportUsers) main.getResourceDistributor().getResource(new TenantIdentifier(null, null, null), ProcessBulkImportUsers.RESOURCE_KEY)); + return main; + } + }