diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index 55b452dd..39799ecc 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -54,7 +54,6 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateThirdPartyIdException; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.multitenancy.sqlStorage.MultitenancySQLStorage; -import io.supertokens.pluginInterface.oauth.exceptions.OAuth2ClientAlreadyExistsForAppException; import io.supertokens.pluginInterface.oauth.sqlStorage.OAuthSQLStorage; import io.supertokens.pluginInterface.passwordless.PasswordlessCode; import io.supertokens.pluginInterface.passwordless.PasswordlessDevice; @@ -123,7 +122,6 @@ public class Start private ResourceDistributor resourceDistributor = new ResourceDistributor(); private String processId; private HikariLoggingAppender appender; - private static final String APP_ID_KEY_NAME = "app_id"; private static final String ACCESS_TOKEN_SIGNING_KEY_NAME = "access_token_signing_key"; private static final String REFRESH_TOKEN_KEY_NAME = "refresh_token_key"; public static boolean isTesting = false; @@ -3083,7 +3081,7 @@ public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(A } @Override - public boolean doesClientIdExistForThisApp(AppIdentifier appIdentifier, String clientId) + public boolean doesClientIdExistForApp(AppIdentifier appIdentifier, String clientId) throws StorageQueryException { try { return OAuthQueries.isClientIdForAppId(this, clientId, appIdentifier); @@ -3093,20 +3091,11 @@ public boolean doesClientIdExistForThisApp(AppIdentifier appIdentifier, String c } @Override - public void addClientForApp(AppIdentifier appIdentifier, String clientId) - throws StorageQueryException, OAuth2ClientAlreadyExistsForAppException { + public void addOrUpdateClientForApp(AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) + throws StorageQueryException { try { - OAuthQueries.insertClientIdForAppId(this, clientId, appIdentifier); + OAuthQueries.insertClientIdForAppId(this, appIdentifier, clientId, isClientCredentialsOnly); } catch (SQLException e) { - - if (e instanceof PSQLException) { - PostgreSQLConfig config = Config.getConfig(this); - ServerErrorMessage serverMessage = ((PSQLException) e).getServerErrorMessage(); - - if (isPrimaryKeyError(serverMessage, config.getOAuthClientTable())) { - throw new OAuth2ClientAlreadyExistsForAppException(); - } - } throw new StorageQueryException(e); } } @@ -3130,6 +3119,82 @@ public List listClientsForApp(AppIdentifier appIdentifier) throws Storag } } + @Override + public void revoke(AppIdentifier appIdentifier, String targetType, String targetValue, long exp) + throws StorageQueryException { + try { + OAuthQueries.revoke(this, appIdentifier, targetType, targetValue, exp); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public boolean isRevoked(AppIdentifier appIdentifier, String[] targetTypes, String[] targetValues, long issuedAt) + throws StorageQueryException { + try { + return OAuthQueries.isRevoked(this, appIdentifier, targetTypes, targetValues, issuedAt); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void addM2MToken(AppIdentifier appIdentifier, String clientId, long iat, long exp) + throws StorageQueryException { + try { + OAuthQueries.addM2MToken(this, appIdentifier, clientId, iat, exp); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void cleanUpExpiredAndRevokedTokens(AppIdentifier appIdentifier) throws StorageQueryException { + try { + OAuthQueries.cleanUpExpiredAndRevokedTokens(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countTotalNumberOfM2MTokensAlive(AppIdentifier appIdentifier) throws StorageQueryException { + try { + return OAuthQueries.countTotalNumberOfM2MTokensAlive(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countTotalNumberOfM2MTokensCreatedSince(AppIdentifier appIdentifier, long since) + throws StorageQueryException { + try { + return OAuthQueries.countTotalNumberOfM2MTokensCreatedSince(this, appIdentifier, since); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countTotalNumberOfClientCredentialsOnlyClientsForApp(AppIdentifier appIdentifier) + throws StorageQueryException { + try { + return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, true); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countTotalNumberOfClientsForApp(AppIdentifier appIdentifier) throws StorageQueryException { + try { + return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, false); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } @TestOnly public int getDbActivityCount(String dbname) throws SQLException, StorageQueryException { diff --git a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java index 9de336ee..228719f5 100644 --- a/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java +++ b/src/main/java/io/supertokens/storage/postgresql/config/PostgreSQLConfig.java @@ -439,10 +439,18 @@ public String getDashboardSessionsTable() { return addSchemaAndPrefixToTableName("dashboard_user_sessions"); } - public String getOAuthClientTable() { + public String getOAuthClientsTable() { return addSchemaAndPrefixToTableName("oauth_clients"); } + public String getOAuthRevokeTable() { + return addSchemaAndPrefixToTableName("oauth_revoke"); + } + + public String getOAuthM2MTokensTable() { + return addSchemaAndPrefixToTableName("oauth_m2m_tokens"); + } + public String getTotpUsersTable() { return addSchemaAndPrefixToTableName("totp_users"); } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java index a4a597e2..02b35de8 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/GeneralQueries.java @@ -552,11 +552,29 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S update(con, TOTPQueries.getQueryToCreateTenantIdIndexForUsedCodesTable(start), NO_OP_SETTER); } - if (!doesTableExists(start, con, Config.getConfig(start).getOAuthClientTable())) { + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthClientsTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(start, OAuthQueries.getQueryToCreateOAuthClientTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthRevokeTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, OAuthQueries.getQueryToCreateOAuthRevokeTable(start), NO_OP_SETTER); + + // index + update(con, OAuthQueries.getQueryToCreateOAuthRevokeTimestampIndex(start), NO_OP_SETTER); + update(con, OAuthQueries.getQueryToCreateOAuthRevokeExpIndex(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, con, Config.getConfig(start).getOAuthM2MTokensTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, OAuthQueries.getQueryToCreateOAuthM2MTokensTable(start), NO_OP_SETTER); + + // index + update(con, OAuthQueries.getQueryToCreateOAuthM2MTokenIatIndex(start), NO_OP_SETTER); + update(con, OAuthQueries.getQueryToCreateOAuthM2MTokenExpIndex(start), NO_OP_SETTER); + } + } catch (Exception e) { if (e.getMessage().contains("schema") && e.getMessage().contains("does not exist") && numberOfRetries < 1) { @@ -627,7 +645,9 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer + getConfig(start).getUserRolesTable() + "," + getConfig(start).getDashboardUsersTable() + "," + getConfig(start).getDashboardSessionsTable() + "," - + getConfig(start).getOAuthClientTable() + "," + + getConfig(start).getOAuthClientsTable() + "," + + getConfig(start).getOAuthRevokeTable() + "," + + getConfig(start).getOAuthM2MTokensTable() + "," + getConfig(start).getTotpUsedCodesTable() + "," + getConfig(start).getTotpUserDevicesTable() + "," + getConfig(start).getTotpUsersTable(); diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java index c0c99f73..3775a02b 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java @@ -14,15 +14,15 @@ import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; - public class OAuthQueries { public static String getQueryToCreateOAuthClientTable(Start start) { String schema = Config.getConfig(start).getTableSchema(); - String oAuth2ClientTable = Config.getConfig(start).getOAuthClientTable(); + String oAuth2ClientTable = Config.getConfig(start).getOAuthClientsTable(); // @formatter:off return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + "app_id VARCHAR(64) DEFAULT 'public'," + "client_id VARCHAR(128) NOT NULL," + + "is_client_credentials_only BOOLEAN NOT NULL," + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "client_id", "pkey") + " PRIMARY KEY (app_id, client_id)," + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "app_id", "fkey") @@ -32,9 +32,70 @@ public static String getQueryToCreateOAuthClientTable(Start start) { // @formatter:on } + public static String getQueryToCreateOAuthRevokeTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String oAuth2ClientTable = Config.getConfig(start).getOAuthRevokeTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + + "app_id VARCHAR(64) DEFAULT 'public'," + + "target_type VARCHAR(16) NOT NULL," + + "target_value VARCHAR(128) NOT NULL," + + "timestamp BIGINT NOT NULL," + + "exp BIGINT NOT NULL," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "client_id", "pkey") + + " PRIMARY KEY (app_id, target_type, target_value)," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "app_id", "fkey") + + " FOREIGN KEY(app_id)" + + " REFERENCES " + Config.getConfig(start).getAppsTable() + "(app_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateOAuthRevokeTimestampIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthRevokeTable(); + return "CREATE INDEX IF NOT EXISTS oauth_revoke_timestamp_index ON " + + oAuth2ClientTable + "(timestamp DESC, app_id DESC);"; + } + + public static String getQueryToCreateOAuthRevokeExpIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthRevokeTable(); + return "CREATE INDEX IF NOT EXISTS oauth_revoke_exp_index ON " + + oAuth2ClientTable + "(exp DESC, app_id DESC);"; + } + + public static String getQueryToCreateOAuthM2MTokensTable(Start start) { + String schema = Config.getConfig(start).getTableSchema(); + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + + "app_id VARCHAR(64) DEFAULT 'public'," + + "client_id VARCHAR(128) NOT NULL," + + "iat BIGINT NOT NULL," + + "exp BIGINT NOT NULL," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "client_id", "pkey") + + " PRIMARY KEY (app_id, client_id, iat)," + + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "app_id", "fkey") + + " FOREIGN KEY(app_id)" + + " REFERENCES " + Config.getConfig(start).getAppsTable() + "(app_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateOAuthM2MTokenIatIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + return "CREATE INDEX IF NOT EXISTS oauth_m2m_token_iat_index ON " + + oAuth2ClientTable + "(iat DESC, app_id DESC);"; + } + + public static String getQueryToCreateOAuthM2MTokenExpIndex(Start start) { + String oAuth2ClientTable = Config.getConfig(start).getOAuthM2MTokensTable(); + return "CREATE INDEX IF NOT EXISTS oauth_m2m_token_exp_index ON " + + oAuth2ClientTable + "(exp DESC, app_id DESC);"; + } + public static boolean isClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE client_id = ? AND app_id = ?"; return execute(start, QUERY, pst -> { @@ -45,7 +106,7 @@ public static boolean isClientIdForAppId(Start start, String clientId, AppIdenti public static List listClientsForApp(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String QUERY = "SELECT client_id FROM " + Config.getConfig(start).getOAuthClientTable() + + String QUERY = "SELECT client_id FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ?"; return execute(start, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -58,19 +119,23 @@ public static List listClientsForApp(Start start, AppIdentifier appIdent }); } - public static void insertClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) + public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifier, String clientId, + boolean isClientCredentialsOnly) throws SQLException, StorageQueryException { - String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientTable() - + "(app_id, client_id) VALUES(?, ?)"; + String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientsTable() + + "(app_id, client_id, is_client_credentials_only) VALUES(?, ?, ?) " + + "ON CONFLICT (app_id, client_id) DO UPDATE SET is_client_credentials_only = ?"; update(start, INSERT, pst -> { pst.setString(1, appIdentifier.getAppId()); pst.setString(2, clientId); + pst.setBoolean(3, isClientCredentialsOnly); + pst.setBoolean(4, isClientCredentialsOnly); }); } public static boolean deleteClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { - String DELETE = "DELETE FROM " + Config.getConfig(start).getOAuthClientTable() + String DELETE = "DELETE FROM " + Config.getConfig(start).getOAuthClientsTable() + " WHERE app_id = ? AND client_id = ?"; int numberOfRow = update(start, DELETE, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -78,4 +143,148 @@ public static boolean deleteClientIdForAppId(Start start, String clientId, AppId }); return numberOfRow > 0; } + + public static void revoke(Start start, AppIdentifier appIdentifier, String targetType, String targetValue, long exp) + throws SQLException, StorageQueryException { + String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthRevokeTable() + + "(app_id, target_type, target_value, timestamp, exp) VALUES (?, ?, ?, ?, ?) " + + "ON CONFLICT (app_id, target_type, target_value) DO UPDATE SET timestamp = ?, exp = ?"; + + long currentTime = System.currentTimeMillis() / 1000; + update(start, INSERT, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, targetType); + pst.setString(3, targetValue); + pst.setLong(4, currentTime); + pst.setLong(5, exp); + pst.setLong(6, currentTime); + pst.setLong(7, exp); + }); + } + + public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String[] targetTypes, + String[] targetValues, long issuedAt) + throws SQLException, StorageQueryException { + String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthRevokeTable() + + " WHERE app_id = ? AND timestamp > ? AND ("; + + for (int i = 0; i < targetTypes.length; i++) { + QUERY += "(target_type = ? AND target_value = ?)"; + + if (i < targetTypes.length - 1) { + QUERY += " OR "; + } + } + + QUERY += ")"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, issuedAt); + + int index = 3; + for (int i = 0; i < targetTypes.length; i++) { + pst.setString(index, targetTypes[i]); + index++; + pst.setString(index, targetValues[i]); + index++; + } + }, ResultSet::next); + } + + public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier appIdentifier, + boolean filterByClientCredentialsOnly) throws SQLException, StorageQueryException { + if (filterByClientCredentialsOnly) { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientsTable() + + " WHERE app_id = ? AND is_client_credentials_only = ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setBoolean(2, true); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } else { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientsTable() + + " WHERE app_id = ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + } + + public static int countTotalNumberOfM2MTokensAlive(Start start, AppIdentifier appIdentifier) + throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + + " WHERE app_id = ? AND exp > ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, System.currentTimeMillis()/1000); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + + public static int countTotalNumberOfM2MTokensCreatedSince(Start start, AppIdentifier appIdentifier, long since) + throws SQLException, StorageQueryException { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + + " WHERE app_id = ? AND iat >= ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, since / 1000); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + + public static void addM2MToken(Start start, AppIdentifier appIdentifier, String clientId, long iat, long exp) + throws SQLException, StorageQueryException { + String QUERY = "INSERT INTO " + Config.getConfig(start).getOAuthM2MTokensTable() + + " (app_id, client_id, iat, exp) VALUES (?, ?, ?, ?)"; + update(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, clientId); + pst.setLong(3, iat); + pst.setLong(4, exp); + }); + } + + public static void cleanUpExpiredAndRevokedTokens(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { + { + // delete expired M2M tokens + String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthM2MTokensTable() + + " WHERE app_id = ? AND exp < ?"; + + long timestamp = System.currentTimeMillis() / 1000 - 3600 * 24 * 31; // expired 31 days ago + update(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, timestamp); + }); + } + + { + // delete expired revoked tokens + String QUERY = "DELETE FROM " + Config.getConfig(start).getOAuthRevokeTable() + + " WHERE app_id = ? AND exp < ?"; + + long timestamp = System.currentTimeMillis() / 1000 - 3600 * 24 * 31; // expired 31 days ago + update(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setLong(2, timestamp); + }); + } + } }