Skip to content

Commit

Permalink
fix: SKFP-1025 add compression of token before encrypt it (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
celinepelletier authored Jun 17, 2024
1 parent 470111e commit a302074
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 21 deletions.
17 changes: 14 additions & 3 deletions src/main/java/io/kidsfirst/core/service/AwsKmsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,23 @@ public Mono<String> encrypt(String original) {
return ByteBufferToString(bufferedCipher);

} catch (UnsupportedEncodingException e) {
// Shouldn't be reachable, handle anyways
// Shouldn't be reachable, handle anyway
log.error(e.getMessage(), e);
return null;
} catch (AWSKMSException e) {
log.error("AWSKMSException occurs when encrypting [{}] with message {}", original, e.getMessage());
log.error("AWSKMSException occurs when encrypting with message {}", e.getMessage());
return null;
}
}).subscribeOn(Schedulers.boundedElastic());

}

@Override
public Mono<String> compressAndEncrypt(String original) {
String compressedOriginal = StringCompressService.compress(original);
return encrypt(compressedOriginal);
}

public Mono<String> decrypt(String cipher) {
return Mono.fromCallable(() -> {
try {
Expand All @@ -66,14 +72,19 @@ public Mono<String> decrypt(String cipher) {
return ByteBufferToString(bufferedOriginal);

} catch (UnsupportedEncodingException e) {
// Shouldn't be reachable, handle anyways
// Shouldn't be reachable, handle anyway
log.error(e.getMessage(), e);
return null;
}
}).subscribeOn(Schedulers.boundedElastic());

}

@Override
public Mono<String> decryptAndDecompress(String cipher) {
return decrypt(cipher).map(StringCompressService::decompress);
}

private ByteBuffer StringToByteBuffer(String string) throws UnsupportedEncodingException {
val bytes = string.getBytes(StandardCharsets.ISO_8859_1);
return ByteBuffer.wrap(bytes);
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/kidsfirst/core/service/KmsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
public interface KmsService {

Mono<String> encrypt(String original);
Mono<String> compressAndEncrypt(String original);

Mono<String> decrypt(String cipher);
Mono<String> decryptAndDecompress(String cipher);
}
11 changes: 10 additions & 1 deletion src/main/java/io/kidsfirst/core/service/MockKmsService.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.kidsfirst.core.service;

import io.kidsfirst.core.service.KmsService;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
Expand All @@ -14,8 +13,18 @@ public Mono<String> encrypt(String original) {
return Mono.just("encrypted_" + original);
}

@Override
public Mono<String> compressAndEncrypt(String original) {
return Mono.just("encrypted_compressed_" + original);
}

@Override
public Mono<String> decrypt(String cipher) {
return Mono.just(cipher.replaceFirst("encrypted_", ""));
}

@Override
public Mono<String> decryptAndDecompress(String cipher) {
return Mono.just(cipher.replaceFirst("decompressed_encrypted_", ""));
}
}
24 changes: 19 additions & 5 deletions src/main/java/io/kidsfirst/core/service/SecretService.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,29 @@ public Mono<Secret> deleteSecret(String service, String userId) {
}

public Mono<String> fetchAccessToken(final AllFences.Fence fence, final String userId) {
return fetchAndDecrypt(userId, fence.keyAccessToken());
return fetchDecryptAndDecompress(userId, fence.keyAccessToken());
}

public Mono<String> fetchRefreshToken(final AllFences.Fence fence, final String userId) {
return fetchAndDecrypt(userId, fence.keyRefreshToken());
return fetchDecryptAndDecompress(userId, fence.keyRefreshToken());
}

public Mono<Secret> persistAccessToken(final AllFences.Fence fence, final String userId, final String token, Long expiration) {
val secret = new Secret(userId, fence.keyAccessToken(), token, expiration);
return encryptAndSave(secret);
return compressEncryptAndSave(secret);
}

public Mono<Secret> persistRefreshToken(final AllFences.Fence fence, final String userId, final String token, Long expiration, boolean isNew) {
//For refresh token, expiration date is set only the first time
if(isNew){
val secret = new Secret(userId, fence.keyRefreshToken(), token, expiration);
return encryptAndSave(secret);
return compressEncryptAndSave(secret);
}
val existingSecret = Mono.fromFuture(secretDao.getSecret(fence.keyRefreshToken(), userId));
return existingSecret.map(s -> s.getExpiration() != null ? s.getExpiration() : expiration).defaultIfEmpty(expiration)
.flatMap(exp -> {
val secret = new Secret(userId, fence.keyRefreshToken(), token, exp);
return encryptAndSave(secret);
return compressEncryptAndSave(secret);
});

}
Expand Down Expand Up @@ -103,12 +103,26 @@ public Mono<Secret> encryptAndSave(final Secret secret) {
.flatMap(s -> Mono.fromFuture(secretDao.saveOrUpdateSecret(s)));
}

public Mono<Secret> compressEncryptAndSave(final Secret secret) {
val secretValue = secret.getSecret();
val encryptedValue = kmsService.compressAndEncrypt(secretValue);
return encryptedValue
.mapNotNull(s -> new Secret(secret.getUserId(), secret.getService(), s, secret.getExpiration()))
.flatMap(s -> Mono.fromFuture(secretDao.saveOrUpdateSecret(s)));
}

public Mono<String> fetchAndDecrypt(final String userId, final String service) {
val secret = getSecret(service, userId);
return secret.mapNotNull(s -> s).flatMap(s -> kmsService.decrypt(s.getSecret()));

}

public Mono<String> fetchDecryptAndDecompress(final String userId, final String service) {
val secret = getSecret(service, userId);
return secret.mapNotNull(s -> s).flatMap(s -> kmsService.decryptAndDecompress(s.getSecret()));

}

public Mono<String> fetchAndDecryptNotExpired(final String userId, final String service) {
return getSecret(service, userId)
.filter(Secret::notExpired)
Expand Down
48 changes: 48 additions & 0 deletions src/main/java/io/kidsfirst/core/service/StringCompressService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.kidsfirst.core.service;

import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

@Slf4j
public class StringCompressService {

public static String compress(String str) {
if (str == null || str.isEmpty()) {
return str;
}

ByteArrayOutputStream out = new ByteArrayOutputStream();

try (GZIPOutputStream gzip = new GZIPOutputStream(out)) {
gzip.write(str.getBytes());
gzip.close();
return out.toString(StandardCharsets.ISO_8859_1);
} catch (IOException e) {
log.error("Error during string compress", e);
return str;
}
}

public static String decompress(String str) {
if (str == null || str.isEmpty()) {
return str;
}

try(GZIPInputStream gis = new GZIPInputStream(new ByteArrayInputStream(str.getBytes(StandardCharsets.ISO_8859_1)));
BufferedReader bf = new BufferedReader(new InputStreamReader(gis, "ISO_8859_1"))) {
StringBuilder outStr = new StringBuilder();
String line;
while ((line = bf.readLine()) != null) {
outStr.append(line);
}
return outStr.toString();
} catch (IOException e) {
log.error("Error during string decompress", e);
return str;
}
}
}
12 changes: 6 additions & 6 deletions src/test/java/io/kidsfirst/keys/DynamicProxyTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ void testProxyWithBothTokenAndRefreshValid() throws ExecutionException, Interrup
//Verify than access token has been refreshed
val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_access_token");
assertThat(accessSecret.notExpired()).isTrue();

//Verify than refresh token has been refreshed, except for expiration date
val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getExpiration()).isEqualTo(expirationRefresh);

}
Expand Down Expand Up @@ -208,13 +208,13 @@ void testProxyWithOnlyRefreshTokenValid() throws ExecutionException, Interrupted
//Verify than access token has been refreshed
val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_access_token");
assertThat(accessSecret.notExpired()).isTrue();

//Verify than refresh token has been refreshed, except for expiration date
val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getExpiration()).isEqualTo(expirationRefresh);
}

Expand Down Expand Up @@ -248,13 +248,13 @@ void testProxyWithRefreshTokenValidAndAccessWithoutExpiration() throws Execution
//Verify than access token has been refreshed
val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_access_token");
assertThat(accessSecret.notExpired()).isTrue();

//Verify than refresh token has been refreshed, except for expiration date
val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getExpiration()).isEqualTo(expirationRefresh);
}

Expand Down
8 changes: 4 additions & 4 deletions src/test/java/io/kidsfirst/keys/FenceDeprecatedTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ void testFenceRefreshPOST() throws Exception {

val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_access_token");

val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_a_fresh_refresh_token");
assertThat(refreshSecret.getExpiration()).isEqualTo(expirationRefresh);

}
Expand Down Expand Up @@ -220,12 +220,12 @@ void testFenceTokenPOST() throws Exception {

val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_fresh_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_fresh_access_token");
assertThat(accessSecret.getExpiration()).isGreaterThan(expiration);

val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_fresh_refresh_token");
assertThat(refreshSecret.getExpiration()).isGreaterThan(expiration);


Expand Down
4 changes: 2 additions & 2 deletions src/test/java/io/kidsfirst/keys/FenceTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ void testFenceTokenExchange() throws Exception {

val accessSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_access", null, null)).get();
assertThat(accessSecret).isNotNull();
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_this_is_fresh_access_token");
assertThat(accessSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_fresh_access_token");
assertThat(accessSecret.notExpired()).isTrue();
assertThat(accessSecret.getExpiration()).isGreaterThan(expiration);

val refreshSecret = secretTable.getItem(new Secret(userIdAndToken.getUserId(), "fence_gen3_refresh", null, null)).get();
assertThat(refreshSecret).isNotNull();
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_this_is_fresh_refresh_token");
assertThat(refreshSecret.getSecret()).isEqualTo("encrypted_compressed_this_is_fresh_refresh_token");
assertThat(refreshSecret.notExpired()).isTrue();
assertThat(refreshSecret.getExpiration()).isGreaterThan(expiration);

Expand Down
Loading

0 comments on commit a302074

Please sign in to comment.