Skip to content

Commit

Permalink
Mfa inmemory (#874)
Browse files Browse the repository at this point in the history
* fix: add createdat to totp device

* fix: inmemory changes for mfa

* fix: mfa stats queries
  • Loading branch information
sattvikc authored Nov 6, 2023
1 parent c10aff7 commit 6cd3c59
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 15 deletions.
14 changes: 11 additions & 3 deletions src/main/java/io/supertokens/inmemorydb/Start.java
Original file line number Diff line number Diff line change
Expand Up @@ -2961,11 +2961,19 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A

@Override
public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException {
return 0; // TODO
try {
return GeneralQueries.getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(this, appIdentifier);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long timestamp) throws StorageQueryException {
return 0; // TODO
public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long sinceTime) throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(this, appIdentifier, sinceTime);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ public String getTenantConfigsTable() {
return "tenant_configs";
}

public String getTenantFirstFactorsTable() {
return "tenant_first_factors";
}

public String getTenantDefaultRequiredFactorIdsTable() {
return "tenant_default_required_factor_ids";
}

public String getTenantThirdPartyProvidersTable() {
return "tenant_thirdparty_providers";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,41 @@ public static void deleteUserActive_Transaction(Connection con, Start start, App
pst.setString(2, userId);
});
}

public static int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime)
throws SQLException, StorageQueryException {
// TODO: Active users are present only on public tenant and MFA users may be present on different storages
String QUERY =
"SELECT COUNT (DISTINCT user_id) as c FROM ("
+ " " // users with more than one login method
+ " SELECT primary_or_recipe_user_id AS user_id FROM ("
+ " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id"
+ " FROM " + Config.getConfig(start).getAppIdToUserIdTable()
+ " WHERE app_id = ? AND primary_or_recipe_user_id IN ("
+ " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable()
+ " WHERE app_id = ? AND last_active_time >= ?"
+ " )"
+ " GROUP BY app_id, primary_or_recipe_user_id"
+ " ) AS nloginmethods"
+ " WHERE num_login_methods > 1"
+ " UNION" // TOTP users
+ " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable()
+ " WHERE app_id = ? AND user_id IN ("
+ " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable()
+ " WHERE app_id = ? AND last_active_time >= ?"
+ " )"
+ " "
+ ") AS all_users";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, appIdentifier.getAppId());
pst.setLong(3, sinceTime);
pst.setString(4, appIdentifier.getAppId());
pst.setString(5, appIdentifier.getAppId());
pst.setLong(6, sinceTime);
}, result -> {
return result.next() ? result.getInt("c") : 0;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,21 @@ public static void createTablesIfNotExists(Start start, Main main) throws SQLExc
update(start, MultitenancyQueries.getQueryToCreateTenantConfigsTable(start), NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantFirstFactorsTable())) {
getInstance(main).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable())) {
getInstance(main).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateDefaultRequiredFactorIdsTable(start), NO_OP_SETTER);

// index
update(start,
MultitenancyQueries.getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(start),
NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProvidersTable())) {
getInstance(main).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProvidersTable(start),
Expand Down Expand Up @@ -1511,6 +1526,32 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti
});
}

public static int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(Start start, AppIdentifier appIdentifier)
throws SQLException, StorageQueryException {
String QUERY =
"SELECT COUNT (DISTINCT user_id) as c FROM ("
+ " " // Users with number of login methods > 1
+ " SELECT primary_or_recipe_user_id AS user_id FROM ("
+ " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id"
+ " FROM " + getConfig(start).getAppIdToUserIdTable()
+ " WHERE app_id = ? "
+ " GROUP BY app_id, primary_or_recipe_user_id"
+ " ) AS nloginmethods"
+ " WHERE num_login_methods > 1"
+ " UNION" // TOTP users
+ " SELECT user_id FROM " + getConfig(start).getTotpUsersTable()
+ " WHERE app_id = ?"
+ " "
+ ") AS all_users";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, appIdentifier.getAppId());
}, result -> {
return result.next() ? result.getInt("c") : 0;
});
}

public static boolean checkIfUsesAccountLinking(Start start, AppIdentifier appIdentifier)
throws SQLException, StorageQueryException {
String QUERY = "SELECT 1 FROM "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.supertokens.inmemorydb.Start;
import io.supertokens.inmemorydb.config.Config;
import io.supertokens.inmemorydb.queries.multitenancy.MfaSqlHelper;
import io.supertokens.inmemorydb.queries.multitenancy.TenantConfigSQLHelper;
import io.supertokens.inmemorydb.queries.multitenancy.ThirdPartyProviderClientSQLHelper;
import io.supertokens.inmemorydb.queries.multitenancy.ThirdPartyProviderSQLHelper;
Expand Down Expand Up @@ -46,11 +47,53 @@ static String getQueryToCreateTenantConfigsTable(Start start) {
+ "email_password_enabled BOOLEAN,"
+ "passwordless_enabled BOOLEAN,"
+ "third_party_enabled BOOLEAN,"
+ "totp_enabled BOOLEAN,"
+ "has_first_factors BOOLEAN DEFAULT FALSE,"
+ "has_default_required_factor_ids BOOLEAN DEFAULT FALSE,"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id)"
+ ");";
// @formatter:on
}

public static String getQueryToCreateFirstFactorsTable(Start start) {
String tableName = Config.getConfig(start).getTenantFirstFactorsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(128),"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable()
+ " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE"
+ ");";
// @formatter:on
}

public static String getQueryToCreateDefaultRequiredFactorIdsTable(Start start) {
String tableName = Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(128),"
+ "order_idx INTEGER NOT NULL,"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable()
+ " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE,"
+ "UNIQUE (connection_uri_domain, app_id, tenant_id, order_idx)"
+ ");";
// @formatter:on
}

public static String getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(Start start) {
return "CREATE INDEX IF NOT EXISTS tenant_default_required_factor_ids_tenant_id_index ON "
+ Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable() + " (order_idx ASC);";
}

static String getQueryToCreateTenantThirdPartyProvidersTable(Start start) {
String tenantThirdPartyProvidersTable = Config.getConfig(start).getTenantThirdPartyProvidersTable();
// @formatter:off
Expand Down Expand Up @@ -114,6 +157,9 @@ private static void executeCreateTenantQueries(Start start, Connection sqlCon, T
ThirdPartyProviderClientSQLHelper.create(start, sqlCon, tenantConfig, provider, providerClient);
}
}

MfaSqlHelper.createFirstFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.firstFactors);
MfaSqlHelper.createDefaultRequiredFactorIds(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.defaultRequiredFactorIds);
}

public static void createTenantConfig(Start start, TenantConfig tenantConfig) throws StorageQueryException, StorageTransactionLogicException {
Expand Down Expand Up @@ -192,7 +238,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep
// Map (tenantIdentifier) -> thirdPartyId -> provider
HashMap<TenantIdentifier, HashMap<String, ThirdPartyConfig.Provider>> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap);

return TenantConfigSQLHelper.selectAll(start, providerMap);
// Map (tenantIdentifier) -> firstFactors
HashMap<TenantIdentifier, String[]> firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start);

// Map (tenantIdentifier) -> defaultRequiredFactorIds
HashMap<TenantIdentifier, String[]> defaultRequiredFactorIdsMap = MfaSqlHelper.selectAllDefaultRequiredFactorIds(start);

return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, defaultRequiredFactorIdsMap);
} catch (SQLException throwables) {
throw new StorageQueryException(throwables);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) {
+ "period INTEGER NOT NULL,"
+ "skew INTEGER NOT NULL,"
+ "verified BOOLEAN NOT NULL,"
+ "created_at BIGINT UNSIGNED NOT NULL,"
+ "PRIMARY KEY (app_id, user_id, device_name),"
+ "FOREIGN KEY (app_id, user_id) REFERENCES " + Config.getConfig(start).getTotpUsersTable()
+ " (app_id, user_id) ON DELETE CASCADE"
Expand Down Expand Up @@ -85,7 +86,7 @@ private static int insertUser_Transaction(Start start, Connection con, AppIdenti
private static int insertDevice_Transaction(Start start, Connection con, AppIdentifier appIdentifier, TOTPDevice device)
throws SQLException, StorageQueryException {
String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable()
+ " (app_id, user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?, ?)";
+ " (app_id, user_id, device_name, secret_key, period, skew, verified, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)";

return update(con, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
Expand All @@ -95,6 +96,7 @@ private static int insertDevice_Transaction(Start start, Connection con, AppIden
pst.setInt(5, device.period);
pst.setInt(6, device.skew);
pst.setBoolean(7, device.verified);
pst.setLong(8, device.createdAt);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
*
* This software is licensed under the Apache License, Version 2.0 (the
* "License") as published by the Apache Software Foundation.
*
* You may not use this file except in compliance with the License. You may
* obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package io.supertokens.inmemorydb.queries.multitenancy;

import io.supertokens.inmemorydb.Start;
import io.supertokens.inmemorydb.config.Config;
import io.supertokens.pluginInterface.exceptions.StorageQueryException;
import io.supertokens.pluginInterface.multitenancy.TenantIdentifier;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;

import static io.supertokens.inmemorydb.QueryExecutorTemplate.execute;
import static io.supertokens.inmemorydb.QueryExecutorTemplate.update;

public class MfaSqlHelper {
public static HashMap<TenantIdentifier, String[]> selectAllFirstFactors(Start start)
throws SQLException, StorageQueryException {
String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM "
+ Config.getConfig(start).getTenantFirstFactorsTable() + ";";
return execute(start, QUERY, pst -> {}, result -> {
HashMap<TenantIdentifier, List<String>> firstFactors = new HashMap<>();

while (result.next()) {
TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), result.getString("app_id"), result.getString("tenant_id"));
if (!firstFactors.containsKey(tenantIdentifier)) {
firstFactors.put(tenantIdentifier, new ArrayList<>());
}

firstFactors.get(tenantIdentifier).add(result.getString("factor_id"));
}

HashMap<TenantIdentifier, String[]> finalResult = new HashMap<>();
for (TenantIdentifier tenantIdentifier : firstFactors.keySet()) {
finalResult.put(tenantIdentifier, firstFactors.get(tenantIdentifier).toArray(new String[0]));
}

return finalResult;
});
}

public static HashMap<TenantIdentifier, String[]> selectAllDefaultRequiredFactorIds(Start start)
throws SQLException, StorageQueryException {
String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id, order_idx FROM "
+ Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable() + " ORDER BY order_idx ASC;";
return execute(start, QUERY, pst -> {}, result -> {
HashMap<TenantIdentifier, List<String>> defaultRequiredFactors = new HashMap<>();

while (result.next()) {
TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"),
result.getString("app_id"), result.getString("tenant_id"));
if (!defaultRequiredFactors.containsKey(tenantIdentifier)) {
defaultRequiredFactors.put(tenantIdentifier, new ArrayList<>());
}

defaultRequiredFactors.get(tenantIdentifier).add(result.getString("factor_id"));
}

HashMap<TenantIdentifier, String[]> finalResult = new HashMap<>();
for (TenantIdentifier tenantIdentifier : defaultRequiredFactors.keySet()) {
finalResult.put(tenantIdentifier, defaultRequiredFactors.get(tenantIdentifier).toArray(new String[0]));
}

return finalResult;
});
}

public static void createFirstFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] firstFactors)
throws SQLException, StorageQueryException {
if (firstFactors == null || firstFactors.length == 0) {
return;
}

String QUERY = "INSERT INTO " + Config.getConfig(start).getTenantFirstFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);";
for (String factorId : new HashSet<>(Arrays.asList(firstFactors))) {
update(sqlCon, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getConnectionUriDomain());
pst.setString(2, tenantIdentifier.getAppId());
pst.setString(3, tenantIdentifier.getTenantId());
pst.setString(4, factorId);
});
}
}

public static void createDefaultRequiredFactorIds(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] defaultRequiredFactorIds)
throws SQLException, StorageQueryException {
if (defaultRequiredFactorIds == null || defaultRequiredFactorIds.length == 0) {
return;
}

String QUERY = "INSERT INTO " + Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id, order_idx) VALUES (?, ?, ?, ?, ?);";
int orderIdx = 0;
for (String factorId : defaultRequiredFactorIds) {
int finalOrderIdx = orderIdx;
update(sqlCon, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getConnectionUriDomain());
pst.setString(2, tenantIdentifier.getAppId());
pst.setString(3, tenantIdentifier.getTenantId());
pst.setString(4, factorId);
pst.setInt(5, finalOrderIdx);
});
orderIdx++;
}
}
}
Loading

0 comments on commit 6cd3c59

Please sign in to comment.