Skip to content

Commit

Permalink
Merge branch 'mfa-cleanup' into mfa-multitenancy
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Oct 12, 2023
2 parents afe0299 + d30cd18 commit 47ce839
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 205 deletions.
78 changes: 2 additions & 76 deletions src/main/java/io/supertokens/storage/postgresql/Start.java
Original file line number Diff line number Diff line change
Expand Up @@ -761,12 +761,7 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(AppIdentifier appIdentifier, Str
} else if (className.equals(ActiveUsersStorage.class.getName())) {
return ActiveUsersQueries.getLastActiveByUserId(this, appIdentifier, userId) != null;
} else if (className.equals(MfaStorage.class.getName())) {
try {
MultitenancyQueries.getAllTenants(this);
return MfaQueries.listFactors(this, appIdentifier, userId).length > 0;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
return false; // there is nothing here
} else {
throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage");
}
Expand Down Expand Up @@ -865,11 +860,7 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi
throw new StorageQueryException(e);
}
} else if (className.equals(MfaStorage.class.getName())) {
try {
MfaQueries.enableFactor(this, tenantIdentifier, userId, "emailpassword");
} catch (SQLException e) {
throw new StorageQueryException(e);
}
/* nothing to be added here */
} else {
throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage");
}
Expand Down Expand Up @@ -2838,71 +2829,6 @@ public int removeExpiredCodes(TenantIdentifier tenantIdentifier, long expiredBef
}
}

// MFA recipe:
@Override
public boolean enableFactor(TenantIdentifier tenantIdentifier, String userId, String factor)
throws StorageQueryException {
try {
int insertedCount = MfaQueries.enableFactor(this, tenantIdentifier, userId, factor);
if (insertedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public String[] listFactors(TenantIdentifier tenantIdentifier, String userId)
throws StorageQueryException {
try {
return MfaQueries.listFactors(this, tenantIdentifier, userId);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean disableFactor(TenantIdentifier tenantIdentifier, String userId, String factor)
throws StorageQueryException {
try {
int deletedCount = MfaQueries.disableFactor(this, tenantIdentifier, userId, factor);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean deleteMfaInfoForUser_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException {
try {
int deletedCount = MfaQueries.deleteUser_Transaction(this, (Connection) con.getConnection(), appIdentifier, userId);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean deleteMfaInfoForUser(TenantIdentifier tenantIdentifier, String userId) throws StorageQueryException {
try {
int deletedCount = MfaQueries.deleteUser(this, tenantIdentifier, userId);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public Set<String> getValidFieldsInConfig() {
return PostgreSQLConfig.getValidFields();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,6 @@ public String getTotpUsedCodesTable() {
return addSchemaAndPrefixToTableName("totp_used_codes");
}

public String getMfaUserFactorsTable() {
return addSchemaAndPrefixToTableName("mfa_user_factors");
}

private String addSchemaAndPrefixToTableName(String tableName) {
return addSchemaToTableName(postgresql_table_names_prefix + tableName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,35 +106,11 @@ public static int countUsersEnabledTotpAndActiveSince(Start start, AppIdentifier
}

public static int countUsersEnabledMfa(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException {
String QUERY = "SELECT COUNT(*) as total FROM (SELECT DISTINCT user_id FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ?) AS app_mfa_users";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
}, result -> {
if (result.next()) {
return result.getInt("total");
}
return 0;
});
return 0; // TODO
}

public static int countUsersEnabledMfaAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime) throws SQLException, StorageQueryException {
// Find unique users from mfa_user_factors table and join with user_last_active table
String QUERY = "SELECT COUNT(*) as total FROM (SELECT DISTINCT user_id FROM " + Config.getConfig(start).getMfaUserFactorsTable() + ") AS mfa_users "
+ "INNER JOIN " + Config.getConfig(start).getUserLastActiveTable() + " AS user_last_active "
+ "ON mfa_users.user_id = user_last_active.user_id "
+ "WHERE user_last_active.app_id = ? "
+ "AND user_last_active.last_active_time >= ?";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setLong(2, sinceTime);
}, result -> {
if (result.next()) {
return result.getInt("total");
}
return 0;
});
return 0; // TODO
}

public static int updateUserLastActive(Start start, AppIdentifier appIdentifier, String userId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.supertokens.storage.postgresql.ConnectionPool;
import io.supertokens.storage.postgresql.Start;
import io.supertokens.storage.postgresql.config.Config;
import io.supertokens.storage.postgresql.queries.GeneralQueries.AccountLinkingInfo;
import io.supertokens.storage.postgresql.utils.Utils;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -516,11 +517,6 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto
update(start, TOTPQueries.getQueryToCreateTenantIdIndexForUsedCodesTable(start), NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getMfaUserFactorsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MfaQueries.getQueryToCreateUserFactorsTable(start), NO_OP_SETTER);
}

} catch (Exception e) {
if (e.getMessage().contains("schema") && e.getMessage().contains("does not exist")
&& numberOfRetries < 1) {
Expand Down Expand Up @@ -589,8 +585,9 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer
+ getConfig(start).getUserRolesTable() + ","
+ getConfig(start).getDashboardUsersTable() + ","
+ getConfig(start).getDashboardSessionsTable() + ","
+ getConfig(start).getTotpUsedCodesTable() + "," + getConfig(start).getTotpUserDevicesTable() + ","
+ getConfig(start).getTotpUsersTable() + "," + getConfig(start).getMfaUserFactorsTable();
+ getConfig(start).getTotpUsedCodesTable() + ","
+ getConfig(start).getTotpUserDevicesTable() + ","
+ getConfig(start).getTotpUsersTable();
update(start, DROP_QUERY, NO_OP_SETTER);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,96 +31,4 @@
import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update;

public class MfaQueries {
public static String getQueryToCreateUserFactorsTable(Start start) {
return "CREATE TABLE IF NOT EXISTS " + Config.getConfig(start).getMfaUserFactorsTable() + " ("
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "user_id VARCHAR(128) NOT NULL,"
+ "factor_id VARCHAR(64) NOT NULL,"
+ "PRIMARY KEY (app_id, tenant_id, user_id, factor_id),"
+ "FOREIGN KEY (app_id, tenant_id)"
+ "REFERENCES " + Config.getConfig(start).getTenantsTable() + " (app_id, tenant_id) ON DELETE CASCADE);";
}

public static int enableFactor(Start start, TenantIdentifier tenantIdentifier, String userId, String factorId)
throws StorageQueryException, SQLException {
String QUERY = "INSERT INTO " + Config.getConfig(start).getMfaUserFactorsTable() + " (app_id, tenant_id, user_id, factor_id) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING";

return update(start, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getAppId());
pst.setString(2, tenantIdentifier.getTenantId());
pst.setString(3, userId);
pst.setString(4, factorId);
});
}


public static String[] listFactors(Start start, TenantIdentifier tenantIdentifier, String userId)
throws StorageQueryException, SQLException {
String QUERY = "SELECT factor_id FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND tenant_id = ? AND user_id = ?";

return execute(start, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getAppId());
pst.setString(2, tenantIdentifier.getTenantId());
pst.setString(3, userId);
}, result -> {
List<String> factors = new ArrayList<>();
while (result.next()) {
factors.add(result.getString("factor_id"));
}

return factors.toArray(String[]::new);
});
}

public static String[] listFactors(Start start, AppIdentifier appIdentifier, String userId)
throws StorageQueryException, SQLException {
String QUERY = "SELECT factor_id FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND user_id = ?";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, userId);
}, result -> {
List<String> factors = new ArrayList<>();
while (result.next()) {
factors.add(result.getString("factor_id"));
}

return factors.toArray(String[]::new);
});
}

public static int disableFactor(Start start, TenantIdentifier tenantIdentifier, String userId, String factorId)
throws StorageQueryException, SQLException {
String QUERY = "DELETE FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND tenant_id = ? AND user_id = ? AND factor_id = ?";

return update(start, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getAppId());
pst.setString(2, tenantIdentifier.getTenantId());
pst.setString(3, userId);
pst.setString(4, factorId);
});
}

public static int deleteUser_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String userId)
throws StorageQueryException, SQLException {
String QUERY = "DELETE FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND user_id = ?";

return update(sqlCon, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, userId);
});
}

public static int deleteUser(Start start, TenantIdentifier tenantIdentifier, String userId)
throws StorageQueryException, SQLException {
String QUERY = "DELETE FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND tenant_id = ? AND user_id = ?";

return update(start, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getAppId());
pst.setString(2, tenantIdentifier.getTenantId());
pst.setString(3, userId);
});
}

}

0 comments on commit 47ce839

Please sign in to comment.