From a7ccacbdb717a03716e49682b331bcc75142a40c Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 6 Feb 2024 12:03:48 +0530 Subject: [PATCH] fix: Load only cud (#917) * fix: update config and validateAndNormalize * fix: impl * fix: PR comments * fix: cleanup * fix: cleanup * fix: pr comments * fix: pr comments * fix: tests * fix: changelog * fix: 400 error * fix: cuds from db --- CHANGELOG.md | 1 + config.yaml | 5 + devConfig.yaml | 4 + .../io/supertokens/config/CoreConfig.java | 19 ++ .../multitenancy/MultitenancyHelper.java | 46 +++- .../supertokens/webserver/WebserverAPI.java | 29 +- .../api/multitenancy/BaseCreateOrUpdate.java | 10 + .../test/multitenant/ConfigTest.java | 4 +- .../test/multitenant/LoadOnlyCUDTest.java | 253 ++++++++++++++++++ 9 files changed, 353 insertions(+), 18 deletions(-) create mode 100644 src/test/java/io/supertokens/test/multitenant/LoadOnlyCUDTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 4365d6623..21e63f238 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [7.0.17] - 2024-02-06 - Fixes issue where error logs were printed to StdOut instead of StdErr. +- Adds new config `supertokens_saas_load_only_cud` that makes the core instance load a particular CUD only, irrespective of the CUDs present in the db. ## [7.0.16] - 2023-12-04 diff --git a/config.yaml b/config.yaml index bc6e7cbce..fdb96d4ba 100644 --- a/config.yaml +++ b/config.yaml @@ -146,3 +146,8 @@ core_config_version: 0 # when CDI version is not specified in the request. When set to null, the core will assume the latest version of the # CDI. # supertokens_max_cdi_version: + + +# (OPTIONAL | Default: null) string value. If specified, the supertokens service will only load the specified CUD even +# if there are more CUDs in the database and block all other CUDs from being used from this instance. +# supertokens_saas_load_only_cud: diff --git a/devConfig.yaml b/devConfig.yaml index 73ccf220d..276b35d42 100644 --- a/devConfig.yaml +++ b/devConfig.yaml @@ -147,3 +147,7 @@ disable_telemetry: true # when CDI version is not specified in the request. When set to null, the core will assume the latest version of the # CDI. # supertokens_max_cdi_version: + +# (OPTIONAL | Default: null) string value. If specified, the supertokens service will only load the specified CUD even +# if there are more CUDs in the database and block all other CUDs from being used from this instance. +# supertokens_saas_load_only_cud: diff --git a/src/main/java/io/supertokens/config/CoreConfig.java b/src/main/java/io/supertokens/config/CoreConfig.java index ba291a4b2..3de06caa7 100644 --- a/src/main/java/io/supertokens/config/CoreConfig.java +++ b/src/main/java/io/supertokens/config/CoreConfig.java @@ -30,7 +30,9 @@ import io.supertokens.pluginInterface.LOG_LEVEL; import io.supertokens.pluginInterface.exceptions.InvalidConfigException; import io.supertokens.utils.SemVer; +import io.supertokens.webserver.Utils; import io.supertokens.webserver.WebserverAPI; +import jakarta.servlet.ServletException; import org.apache.catalina.filters.RemoteAddrFilter; import org.jetbrains.annotations.TestOnly; @@ -197,6 +199,10 @@ public class CoreConfig { @JsonProperty private String supertokens_max_cdi_version = null; + @ConfigYamlOnly + @JsonProperty + private String supertokens_saas_load_only_cud = null; + @IgnoreForAnnotationCheck private Set allowedLogLevels = null; @@ -254,6 +260,10 @@ public String getBasePath() { return base_path; } + public String getSuperTokensLoadOnlyCUD() { + return supertokens_saas_load_only_cud; + } + public enum PASSWORD_HASHING_ALG { ARGON2, BCRYPT, FIREBASE_SCRYPT } @@ -667,6 +677,15 @@ void normalizeAndValidate(Main main, boolean includeConfigFilePath) throws Inval host = cliHost; } + if (supertokens_saas_load_only_cud != null) { + try { + supertokens_saas_load_only_cud = + Utils.normalizeAndValidateConnectionUriDomain(supertokens_saas_load_only_cud, true); + } catch (ServletException e) { + throw new InvalidConfigException("supertokens_saas_load_only_cud is invalid"); + } + } + access_token_validity = access_token_validity * 1000; access_token_dynamic_signing_key_update_interval = access_token_dynamic_signing_key_update_interval * 3600 * 1000; refresh_token_validity = refresh_token_validity * 60 * 1000; diff --git a/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java b/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java index fe9397622..ad0433238 100644 --- a/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java +++ b/src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java @@ -51,9 +51,20 @@ public class MultitenancyHelper extends ResourceDistributor.SingletonResource { private Main main; private TenantConfig[] tenantConfigs; + // when the core has `supertokens_saas_load_only_cud` set, the tenantConfigs array will be filtered + // based on the config value. However, we need to keep all the list of CUDs from the db to be able + // to check if the CUD is present in the DB or not, while processing the requests. + private final Set dangerous_allCUDsFromDb = new HashSet<>(); + private MultitenancyHelper(Main main) throws StorageQueryException { this.main = main; - this.tenantConfigs = getAllTenantsFromDb(); + TenantConfig[] allTenantsFromDb = getAllTenantsFromDb(); + this.tenantConfigs = this.getFilteredTenantConfigs(allTenantsFromDb); + this.dangerous_allCUDsFromDb.clear(); + + for (TenantConfig config : allTenantsFromDb) { + this.dangerous_allCUDsFromDb.add(config.tenantIdentifier.getConnectionUriDomain()); + } } public static MultitenancyHelper getInstance(Main main) { @@ -108,10 +119,11 @@ public List refreshTenantsInCoreBasedOnChangesInCoreConfigOrIf return main.getResourceDistributor().withResourceDistributorLock(() -> { try { TenantConfig[] tenantsFromDb = getAllTenantsFromDb(); + TenantConfig[] filteredTenantsFromDb = this.getFilteredTenantConfigs(tenantsFromDb); Map normalizedTenantsFromDb = Config.getNormalisedConfigsForAllTenants( - tenantsFromDb, Config.getBaseConfigAsJsonObject(main)); + filteredTenantsFromDb, Config.getBaseConfigAsJsonObject(main)); Map normalizedTenantsFromMemory = Config.getNormalisedConfigsForAllTenants( @@ -129,9 +141,14 @@ public List refreshTenantsInCoreBasedOnChangesInCoreConfigOrIf } } - boolean sameNumberOfTenants = tenantsFromDb.length == this.tenantConfigs.length; + boolean sameNumberOfTenants = + filteredTenantsFromDb.length == this.tenantConfigs.length; - this.tenantConfigs = tenantsFromDb; + this.dangerous_allCUDsFromDb.clear(); + for (TenantConfig tenant : tenantsFromDb) { + this.dangerous_allCUDsFromDb.add(tenant.tenantIdentifier.getConnectionUriDomain()); + } + this.tenantConfigs = filteredTenantsFromDb; if (tenantsThatChanged.size() == 0 && sameNumberOfTenants) { return tenantsThatChanged; } @@ -190,7 +207,7 @@ public void loadStorageLayer() throws IOException, InvalidConfigException { public void loadFeatureFlag(List tenantsThatChanged) { List apps = new ArrayList<>(); Set appsSet = new HashSet<>(); - for (TenantConfig t : tenantConfigs) { + for (TenantConfig t : this.tenantConfigs) { if (appsSet.contains(t.tenantIdentifier.toAppIdentifier())) { continue; } @@ -204,7 +221,7 @@ public void loadSigningKeys(List tenantsThatChanged) throws UnsupportedJWTSigningAlgorithmException { List apps = new ArrayList<>(); Set appsSet = new HashSet<>(); - for (TenantConfig t : tenantConfigs) { + for (TenantConfig t : this.tenantConfigs) { if (appsSet.contains(t.tenantIdentifier.toAppIdentifier())) { continue; } @@ -238,4 +255,21 @@ public TenantConfig[] getAllTenants() { throw new IllegalStateException(e); } } + + private TenantConfig[] getFilteredTenantConfigs(TenantConfig[] inputTenantConfigs) { + String loadOnlyCUD = Config.getBaseConfig(main).getSuperTokensLoadOnlyCUD(); + + if (loadOnlyCUD == null) { + return inputTenantConfigs; + } + + return Arrays.stream(inputTenantConfigs) + .filter(tenantConfig -> tenantConfig.tenantIdentifier.getConnectionUriDomain().equals(loadOnlyCUD) + || tenantConfig.tenantIdentifier.getConnectionUriDomain().equals(TenantIdentifier.DEFAULT_CONNECTION_URI)) + .toArray(TenantConfig[]::new); + } + + public boolean isConnectionUriDomainPresentInDb(String cud) { + return this.dangerous_allCUDsFromDb.contains(cud); + } } diff --git a/src/main/java/io/supertokens/webserver/WebserverAPI.java b/src/main/java/io/supertokens/webserver/WebserverAPI.java index a94e04d83..a87130a75 100644 --- a/src/main/java/io/supertokens/webserver/WebserverAPI.java +++ b/src/main/java/io/supertokens/webserver/WebserverAPI.java @@ -24,6 +24,7 @@ import io.supertokens.config.CoreConfig; import io.supertokens.exceptions.QuitProgramException; import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; +import io.supertokens.multitenancy.MultitenancyHelper; import io.supertokens.multitenancy.exception.BadPermissionException; import io.supertokens.output.Logging; import io.supertokens.pluginInterface.Storage; @@ -287,15 +288,18 @@ private String getConnectionUriDomain(HttpServletRequest req) throws ServletExce String connectionUriDomain = req.getServerName(); connectionUriDomain = Utils.normalizeAndValidateConnectionUriDomain(connectionUriDomain, false); - try { - if (Config.getConfig(new TenantIdentifier(connectionUriDomain, null, null), main) == - Config.getConfig(new TenantIdentifier(null, null, null), main)) { - return null; + if (MultitenancyHelper.getInstance(main).isConnectionUriDomainPresentInDb(connectionUriDomain)) { + CoreConfig baseConfig = Config.getBaseConfig(main); + if (baseConfig.getSuperTokensLoadOnlyCUD() != null) { + if (!connectionUriDomain.equals(baseConfig.getSuperTokensLoadOnlyCUD())) { + throw new ServletException(new BadRequestException("Connection URI domain is disallowed")); + } } - } catch (TenantOrAppNotFoundException e) { - throw new IllegalStateException(e); + + return connectionUriDomain; } - return connectionUriDomain; + + return null; } @TestOnly @@ -481,10 +485,13 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws } Logging.info(main, tenantIdentifier, "API ended: " + req.getRequestURI() + ". Method: " + req.getMethod(), false); - try { - RequestStats.getInstance(main, tenantIdentifier.toAppIdentifier()).updateRequestStats(); - } catch (TenantOrAppNotFoundException e) { - // Ignore the error as we would have already sent the response for tenantNotFound + + if (tenantIdentifier != null) { + try { + RequestStats.getInstance(main, tenantIdentifier.toAppIdentifier()).updateRequestStats(); + } catch (TenantOrAppNotFoundException e) { + // Ignore the error as we would have already sent the response for tenantNotFound + } } } diff --git a/src/main/java/io/supertokens/webserver/api/multitenancy/BaseCreateOrUpdate.java b/src/main/java/io/supertokens/webserver/api/multitenancy/BaseCreateOrUpdate.java index ac930fb7b..87bc319b5 100644 --- a/src/main/java/io/supertokens/webserver/api/multitenancy/BaseCreateOrUpdate.java +++ b/src/main/java/io/supertokens/webserver/api/multitenancy/BaseCreateOrUpdate.java @@ -19,6 +19,8 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; import io.supertokens.Main; +import io.supertokens.config.Config; +import io.supertokens.config.CoreConfig; import io.supertokens.featureflag.exceptions.FeatureNotEnabledException; import io.supertokens.multitenancy.Multitenancy; import io.supertokens.multitenancy.exception.BadPermissionException; @@ -49,6 +51,14 @@ protected void handle(HttpServletRequest req, TenantIdentifier sourceTenantIdent HttpServletResponse resp) throws ServletException, IOException { + CoreConfig baseConfig = Config.getBaseConfig(main); + if (baseConfig.getSuperTokensLoadOnlyCUD() != null) { + if (!(targetTenantIdentifier.getConnectionUriDomain().equals(TenantIdentifier.DEFAULT_CONNECTION_URI) || targetTenantIdentifier.getConnectionUriDomain().equals(baseConfig.getSuperTokensLoadOnlyCUD()))) { + throw new ServletException(new BadRequestException("Creation of connection uri domain or app or " + + "tenant is disallowed")); + } + } + TenantConfig tenantConfig = Multitenancy.getTenantInfo(main, new TenantIdentifier(targetTenantIdentifier.getConnectionUriDomain(), targetTenantIdentifier.getAppId(), targetTenantIdentifier.getTenantId())); diff --git a/src/test/java/io/supertokens/test/multitenant/ConfigTest.java b/src/test/java/io/supertokens/test/multitenant/ConfigTest.java index 2885d4318..13cdfe6be 100644 --- a/src/test/java/io/supertokens/test/multitenant/ConfigTest.java +++ b/src/test/java/io/supertokens/test/multitenant/ConfigTest.java @@ -1914,6 +1914,7 @@ public void testAllConflictingConfigs() throws Exception { "argon2_memory_kb", "argon2_parallelism", "bcrypt_log_rounds", + "supertokens_saas_load_only_cud" }; Object[] disallowedValues = new Object[]{ 3567, // port @@ -1930,6 +1931,7 @@ public void testAllConflictingConfigs() throws Exception { 87795, // argon2_memory_kb 2, // argon2_parallelism 11, // bcrypt_log_rounds + "mydomain.com", // supertokens_saas_load_only_cud }; process.kill(); @@ -1995,7 +1997,7 @@ public void testAllConflictingConfigs() throws Exception { new Object[]{true, false}, // disable_telemetry new Object[]{"BCRYPT", "ARGON2"}, // password_hashing_alg new Object[]{"abcd1234abcd1234abcd1234abcd1234", "qwer1234qwer1234qwer1234qwer1234"}, // firebase_password_hashing_signer_key - new Object[]{"2.21", "3.0"} // supertokens_max_cdi_version + new Object[]{"2.21", "3.0"}, // supertokens_max_cdi_version }; for (int i=0; i> uniqueUserPoolIdsTenants = StorageLayer.getTenantsWithUniqueUserPoolId(process.getProcess()); + Cronjobs.addCronjob(process.getProcess(), LoadOnlyCUDTest.PerAppCronjob.getInstance(process.getProcess(), uniqueUserPoolIdsTenants)); + + Thread.sleep(3000); + Set appIdentifiersFromCron = PerAppCronjob.getInstance(process.getProcess(), uniqueUserPoolIdsTenants).appIdentifiers; + assertEquals(2, appIdentifiersFromCron.size()); + for (AppIdentifier app : appIdentifiersFromCron) { + assertNotEquals("localhost.org", app.getConnectionUriDomain()); + } + + process.kill(); + assertNotNull(process.checkOrWaitForEvent(ProcessState.PROCESS_STATE.STOPPED)); + } + + static class PerAppCronjob extends CronTask { + private static final String RESOURCE_ID = "io.supertokens.test.CronjobTest.NormalCronjob"; + + private PerAppCronjob(Main main, List> tenantsInfo) { + super("PerTenantCronjob", main, tenantsInfo, true); + } + + Set appIdentifiers = new HashSet<>(); + + public static LoadOnlyCUDTest.PerAppCronjob getInstance(Main main, List> tenantsInfo) { + try { + return (LoadOnlyCUDTest.PerAppCronjob) main.getResourceDistributor().getResource(new TenantIdentifier(null, null, null), RESOURCE_ID); + } catch (TenantOrAppNotFoundException e) { + return (LoadOnlyCUDTest.PerAppCronjob) main.getResourceDistributor() + .setResource(new TenantIdentifier(null, null, null), RESOURCE_ID, new LoadOnlyCUDTest.PerAppCronjob(main, tenantsInfo)); + } + } + + @Override + public int getIntervalTimeSeconds() { + return 1; + } + + @Override + public int getInitialWaitTimeSeconds() { + return 0; + } + + @Override + protected void doTaskPerApp(AppIdentifier app) throws Exception { + appIdentifiers.add(app); + } + } +} \ No newline at end of file