Skip to content

Commit

Permalink
fix: Load only cud (#917)
Browse files Browse the repository at this point in the history
* fix: update config and validateAndNormalize

* fix: impl

* fix: PR comments

* fix: cleanup

* fix: cleanup

* fix: pr comments

* fix: pr comments

* fix: tests

* fix: changelog

* fix: 400 error

* fix: cuds from db
  • Loading branch information
sattvikc authored Feb 6, 2024
1 parent 7d341c4 commit a7ccacb
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [7.0.17] - 2024-02-06

- Fixes issue where error logs were printed to StdOut instead of StdErr.
- Adds new config `supertokens_saas_load_only_cud` that makes the core instance load a particular CUD only, irrespective of the CUDs present in the db.

## [7.0.16] - 2023-12-04

Expand Down
5 changes: 5 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,8 @@ core_config_version: 0
# when CDI version is not specified in the request. When set to null, the core will assume the latest version of the
# CDI.
# supertokens_max_cdi_version:


# (OPTIONAL | Default: null) string value. If specified, the supertokens service will only load the specified CUD even
# if there are more CUDs in the database and block all other CUDs from being used from this instance.
# supertokens_saas_load_only_cud:
4 changes: 4 additions & 0 deletions devConfig.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,7 @@ disable_telemetry: true
# when CDI version is not specified in the request. When set to null, the core will assume the latest version of the
# CDI.
# supertokens_max_cdi_version:

# (OPTIONAL | Default: null) string value. If specified, the supertokens service will only load the specified CUD even
# if there are more CUDs in the database and block all other CUDs from being used from this instance.
# supertokens_saas_load_only_cud:
19 changes: 19 additions & 0 deletions src/main/java/io/supertokens/config/CoreConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import io.supertokens.pluginInterface.LOG_LEVEL;
import io.supertokens.pluginInterface.exceptions.InvalidConfigException;
import io.supertokens.utils.SemVer;
import io.supertokens.webserver.Utils;
import io.supertokens.webserver.WebserverAPI;
import jakarta.servlet.ServletException;
import org.apache.catalina.filters.RemoteAddrFilter;
import org.jetbrains.annotations.TestOnly;

Expand Down Expand Up @@ -197,6 +199,10 @@ public class CoreConfig {
@JsonProperty
private String supertokens_max_cdi_version = null;

@ConfigYamlOnly
@JsonProperty
private String supertokens_saas_load_only_cud = null;

@IgnoreForAnnotationCheck
private Set<LOG_LEVEL> allowedLogLevels = null;

Expand Down Expand Up @@ -254,6 +260,10 @@ public String getBasePath() {
return base_path;
}

public String getSuperTokensLoadOnlyCUD() {
return supertokens_saas_load_only_cud;
}

public enum PASSWORD_HASHING_ALG {
ARGON2, BCRYPT, FIREBASE_SCRYPT
}
Expand Down Expand Up @@ -667,6 +677,15 @@ void normalizeAndValidate(Main main, boolean includeConfigFilePath) throws Inval
host = cliHost;
}

if (supertokens_saas_load_only_cud != null) {
try {
supertokens_saas_load_only_cud =
Utils.normalizeAndValidateConnectionUriDomain(supertokens_saas_load_only_cud, true);
} catch (ServletException e) {
throw new InvalidConfigException("supertokens_saas_load_only_cud is invalid");
}
}

access_token_validity = access_token_validity * 1000;
access_token_dynamic_signing_key_update_interval = access_token_dynamic_signing_key_update_interval * 3600 * 1000;
refresh_token_validity = refresh_token_validity * 60 * 1000;
Expand Down
46 changes: 40 additions & 6 deletions src/main/java/io/supertokens/multitenancy/MultitenancyHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,20 @@ public class MultitenancyHelper extends ResourceDistributor.SingletonResource {
private Main main;
private TenantConfig[] tenantConfigs;

// when the core has `supertokens_saas_load_only_cud` set, the tenantConfigs array will be filtered
// based on the config value. However, we need to keep all the list of CUDs from the db to be able
// to check if the CUD is present in the DB or not, while processing the requests.
private final Set<String> dangerous_allCUDsFromDb = new HashSet<>();

private MultitenancyHelper(Main main) throws StorageQueryException {
this.main = main;
this.tenantConfigs = getAllTenantsFromDb();
TenantConfig[] allTenantsFromDb = getAllTenantsFromDb();
this.tenantConfigs = this.getFilteredTenantConfigs(allTenantsFromDb);
this.dangerous_allCUDsFromDb.clear();

for (TenantConfig config : allTenantsFromDb) {
this.dangerous_allCUDsFromDb.add(config.tenantIdentifier.getConnectionUriDomain());
}
}

public static MultitenancyHelper getInstance(Main main) {
Expand Down Expand Up @@ -108,10 +119,11 @@ public List<TenantIdentifier> refreshTenantsInCoreBasedOnChangesInCoreConfigOrIf
return main.getResourceDistributor().withResourceDistributorLock(() -> {
try {
TenantConfig[] tenantsFromDb = getAllTenantsFromDb();
TenantConfig[] filteredTenantsFromDb = this.getFilteredTenantConfigs(tenantsFromDb);

Map<ResourceDistributor.KeyClass, JsonObject> normalizedTenantsFromDb =
Config.getNormalisedConfigsForAllTenants(
tenantsFromDb, Config.getBaseConfigAsJsonObject(main));
filteredTenantsFromDb, Config.getBaseConfigAsJsonObject(main));

Map<ResourceDistributor.KeyClass, JsonObject> normalizedTenantsFromMemory =
Config.getNormalisedConfigsForAllTenants(
Expand All @@ -129,9 +141,14 @@ public List<TenantIdentifier> refreshTenantsInCoreBasedOnChangesInCoreConfigOrIf
}
}

boolean sameNumberOfTenants = tenantsFromDb.length == this.tenantConfigs.length;
boolean sameNumberOfTenants =
filteredTenantsFromDb.length == this.tenantConfigs.length;

this.tenantConfigs = tenantsFromDb;
this.dangerous_allCUDsFromDb.clear();
for (TenantConfig tenant : tenantsFromDb) {
this.dangerous_allCUDsFromDb.add(tenant.tenantIdentifier.getConnectionUriDomain());
}
this.tenantConfigs = filteredTenantsFromDb;
if (tenantsThatChanged.size() == 0 && sameNumberOfTenants) {
return tenantsThatChanged;
}
Expand Down Expand Up @@ -190,7 +207,7 @@ public void loadStorageLayer() throws IOException, InvalidConfigException {
public void loadFeatureFlag(List<TenantIdentifier> tenantsThatChanged) {
List<AppIdentifier> apps = new ArrayList<>();
Set<AppIdentifier> appsSet = new HashSet<>();
for (TenantConfig t : tenantConfigs) {
for (TenantConfig t : this.tenantConfigs) {
if (appsSet.contains(t.tenantIdentifier.toAppIdentifier())) {
continue;
}
Expand All @@ -204,7 +221,7 @@ public void loadSigningKeys(List<TenantIdentifier> tenantsThatChanged)
throws UnsupportedJWTSigningAlgorithmException {
List<AppIdentifier> apps = new ArrayList<>();
Set<AppIdentifier> appsSet = new HashSet<>();
for (TenantConfig t : tenantConfigs) {
for (TenantConfig t : this.tenantConfigs) {
if (appsSet.contains(t.tenantIdentifier.toAppIdentifier())) {
continue;
}
Expand Down Expand Up @@ -238,4 +255,21 @@ public TenantConfig[] getAllTenants() {
throw new IllegalStateException(e);
}
}

private TenantConfig[] getFilteredTenantConfigs(TenantConfig[] inputTenantConfigs) {
String loadOnlyCUD = Config.getBaseConfig(main).getSuperTokensLoadOnlyCUD();

if (loadOnlyCUD == null) {
return inputTenantConfigs;
}

return Arrays.stream(inputTenantConfigs)
.filter(tenantConfig -> tenantConfig.tenantIdentifier.getConnectionUriDomain().equals(loadOnlyCUD)
|| tenantConfig.tenantIdentifier.getConnectionUriDomain().equals(TenantIdentifier.DEFAULT_CONNECTION_URI))
.toArray(TenantConfig[]::new);
}

public boolean isConnectionUriDomainPresentInDb(String cud) {
return this.dangerous_allCUDsFromDb.contains(cud);
}
}
29 changes: 18 additions & 11 deletions src/main/java/io/supertokens/webserver/WebserverAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.supertokens.config.CoreConfig;
import io.supertokens.exceptions.QuitProgramException;
import io.supertokens.featureflag.exceptions.FeatureNotEnabledException;
import io.supertokens.multitenancy.MultitenancyHelper;
import io.supertokens.multitenancy.exception.BadPermissionException;
import io.supertokens.output.Logging;
import io.supertokens.pluginInterface.Storage;
Expand Down Expand Up @@ -287,15 +288,18 @@ private String getConnectionUriDomain(HttpServletRequest req) throws ServletExce
String connectionUriDomain = req.getServerName();
connectionUriDomain = Utils.normalizeAndValidateConnectionUriDomain(connectionUriDomain, false);

try {
if (Config.getConfig(new TenantIdentifier(connectionUriDomain, null, null), main) ==
Config.getConfig(new TenantIdentifier(null, null, null), main)) {
return null;
if (MultitenancyHelper.getInstance(main).isConnectionUriDomainPresentInDb(connectionUriDomain)) {
CoreConfig baseConfig = Config.getBaseConfig(main);
if (baseConfig.getSuperTokensLoadOnlyCUD() != null) {
if (!connectionUriDomain.equals(baseConfig.getSuperTokensLoadOnlyCUD())) {
throw new ServletException(new BadRequestException("Connection URI domain is disallowed"));
}
}
} catch (TenantOrAppNotFoundException e) {
throw new IllegalStateException(e);

return connectionUriDomain;
}
return connectionUriDomain;

return null;
}

@TestOnly
Expand Down Expand Up @@ -481,10 +485,13 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws
}
Logging.info(main, tenantIdentifier, "API ended: " + req.getRequestURI() + ". Method: " + req.getMethod(),
false);
try {
RequestStats.getInstance(main, tenantIdentifier.toAppIdentifier()).updateRequestStats();
} catch (TenantOrAppNotFoundException e) {
// Ignore the error as we would have already sent the response for tenantNotFound

if (tenantIdentifier != null) {
try {
RequestStats.getInstance(main, tenantIdentifier.toAppIdentifier()).updateRequestStats();
} catch (TenantOrAppNotFoundException e) {
// Ignore the error as we would have already sent the response for tenantNotFound
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import io.supertokens.Main;
import io.supertokens.config.Config;
import io.supertokens.config.CoreConfig;
import io.supertokens.featureflag.exceptions.FeatureNotEnabledException;
import io.supertokens.multitenancy.Multitenancy;
import io.supertokens.multitenancy.exception.BadPermissionException;
Expand Down Expand Up @@ -49,6 +51,14 @@ protected void handle(HttpServletRequest req, TenantIdentifier sourceTenantIdent
HttpServletResponse resp)
throws ServletException, IOException {

CoreConfig baseConfig = Config.getBaseConfig(main);
if (baseConfig.getSuperTokensLoadOnlyCUD() != null) {
if (!(targetTenantIdentifier.getConnectionUriDomain().equals(TenantIdentifier.DEFAULT_CONNECTION_URI) || targetTenantIdentifier.getConnectionUriDomain().equals(baseConfig.getSuperTokensLoadOnlyCUD()))) {
throw new ServletException(new BadRequestException("Creation of connection uri domain or app or " +
"tenant is disallowed"));
}
}

TenantConfig tenantConfig = Multitenancy.getTenantInfo(main,
new TenantIdentifier(targetTenantIdentifier.getConnectionUriDomain(), targetTenantIdentifier.getAppId(),
targetTenantIdentifier.getTenantId()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,7 @@ public void testAllConflictingConfigs() throws Exception {
"argon2_memory_kb",
"argon2_parallelism",
"bcrypt_log_rounds",
"supertokens_saas_load_only_cud"
};
Object[] disallowedValues = new Object[]{
3567, // port
Expand All @@ -1930,6 +1931,7 @@ public void testAllConflictingConfigs() throws Exception {
87795, // argon2_memory_kb
2, // argon2_parallelism
11, // bcrypt_log_rounds
"mydomain.com", // supertokens_saas_load_only_cud
};

process.kill();
Expand Down Expand Up @@ -1995,7 +1997,7 @@ public void testAllConflictingConfigs() throws Exception {
new Object[]{true, false}, // disable_telemetry
new Object[]{"BCRYPT", "ARGON2"}, // password_hashing_alg
new Object[]{"abcd1234abcd1234abcd1234abcd1234", "qwer1234qwer1234qwer1234qwer1234"}, // firebase_password_hashing_signer_key
new Object[]{"2.21", "3.0"} // supertokens_max_cdi_version
new Object[]{"2.21", "3.0"}, // supertokens_max_cdi_version
};

for (int i=0; i<conflictingInSameUserPool.length; i++) {
Expand Down
Loading

0 comments on commit a7ccacb

Please sign in to comment.