Skip to content

Commit

Permalink
add Key Encapsulation Method to EncryptionAlgorithm
Browse files Browse the repository at this point in the history
moar cleanup
  • Loading branch information
Hellblazer committed Jan 8, 2024
1 parent c2d42a8 commit 6b35ab4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.salesforce.apollo.cryptography;

import javax.crypto.DecapsulateException;
import javax.crypto.KEM;
import javax.crypto.SecretKey;
import java.math.BigInteger;
import java.security.*;
import java.security.interfaces.EdECPrivateKey;
Expand All @@ -21,10 +24,16 @@ public String curveName() {
return "Curve25519";
}

@Override
public int getCode() {
return 1;
}

@Override
public int publicKeyLength() {
return 32;
}

}, X_448 {
@Override
public String algorithmName() {
Expand All @@ -36,15 +45,33 @@ public String curveName() {
return "Curve448";
}

@Override
public int getCode() {
return 2;
}

@Override
public int publicKeyLength() {
return 57;
}
};

public static final EncryptionAlgorithm DEFAULT = X_25519;
public static final String DHKEM = "DHKEM";
public static final String XDH = "XDH";

public static EncryptionAlgorithm lookup(int code) {
return switch (code) {
case 0 -> throw new IllegalArgumentException("Uninitialized enum value");
case 1 -> X_25519;
case 2 -> X_448;
default -> throw new IllegalArgumentException("Unknown code: " + code);
};
}

public static EncryptionAlgorithm lookup(PrivateKey privateKey) {
return switch (privateKey.getAlgorithm()) {
case "XDH" -> lookupX(((EdECPrivateKey) privateKey).getParams());
case XDH -> lookupX(((EdECPrivateKey) privateKey).getParams());
case "x25519" -> X_25519;
case "x448" -> X_448;
default -> throw new IllegalArgumentException("Unknown algorithm: " + privateKey.getAlgorithm());
Expand All @@ -53,7 +80,7 @@ public static EncryptionAlgorithm lookup(PrivateKey privateKey) {

public static EncryptionAlgorithm lookup(PublicKey publicKey) {
return switch (publicKey.getAlgorithm()) {
case "XDH" -> lookupX(((EdECPublicKey) publicKey).getParams());
case XDH -> lookupX(((EdECPublicKey) publicKey).getParams());
case "X25519" -> X_25519;
case "X448" -> X_448;
default -> throw new IllegalArgumentException("Unknown algorithm: " + publicKey.getAlgorithm());
Expand All @@ -73,13 +100,31 @@ private static EncryptionAlgorithm lookupX(NamedParameterSpec params) {

abstract public String curveName();

final public SecretKey decapsulate(PrivateKey privateKey, byte[] encapsulated, String algorithm) {
try {
var kem = KEM.getInstance(DHKEM);
return kem.newDecapsulator(privateKey).decapsulate(encapsulated, 0, encapsulated.length, algorithm);
} catch (NoSuchAlgorithmException | InvalidKeyException | DecapsulateException e) {
throw new IllegalArgumentException("Invalid public key", e);
}
}

final public KEM.Encapsulated encapsulated(PublicKey publicKey) {
try {
var kem = KEM.getInstance(DHKEM);
return kem.newEncapsulator(publicKey).encapsulate();
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new IllegalArgumentException("Invalid public key", e);
}
}

final public byte[] encode(PublicKey publicKey) {
return ((XECPublicKey) publicKey).getU().toByteArray();
}

final public KeyPair generateKeyPair() {
try {
KeyPairGenerator kpg = KeyPairGenerator.getInstance("XDH");
KeyPairGenerator kpg = KeyPairGenerator.getInstance(XDH);
kpg.initialize(getParamSpec());
return kpg.generateKeyPair();
} catch (NoSuchAlgorithmException | InvalidAlgorithmParameterException e) {
Expand All @@ -89,17 +134,19 @@ final public KeyPair generateKeyPair() {

final public KeyPair generateKeyPair(SecureRandom secureRandom) {
try {
KeyPairGenerator kpg = KeyPairGenerator.getInstance("XDH");
KeyPairGenerator kpg = KeyPairGenerator.getInstance(XDH);
kpg.initialize(getParamSpec(), secureRandom);
return kpg.generateKeyPair();
} catch (NoSuchAlgorithmException | InvalidAlgorithmParameterException e) {
throw new IllegalArgumentException("Cannot generate key pair", e);
}
}

abstract public int getCode();

final public PublicKey publicKey(byte[] bytes) {
try {
KeyFactory kf = KeyFactory.getInstance("XDH");
KeyFactory kf = KeyFactory.getInstance(XDH);
BigInteger u = new BigInteger(bytes);
XECPublicKeySpec pubSpec = new XECPublicKeySpec(getParamSpec(), u);
return kf.generatePublic(pubSpec);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,27 @@ public void testEncoding() throws Exception {
}

@Test
public void testRoundTrip() throws Exception {
public void testKEM() throws Exception {
var entropy = SecureRandom.getInstance("SHA1PRNG");
entropy.setSeed(new byte[] { 6, 6, 6 });

var algorithm = EncryptionAlgorithm.X_25519;
var pair1 = algorithm.generateKeyPair(entropy);
assertNotNull(pair1);
var pair2 = algorithm.generateKeyPair(entropy);
assertNotNull(pair2);

var encapsulated = algorithm.encapsulated(pair2.getPublic());
assertNotNull(encapsulated);

var secretKey = algorithm.decapsulate(pair2.getPrivate(), encapsulated.encapsulation(), "AES");

assertNotNull(secretKey);
assertEquals(encapsulated.key(), secretKey);
}

@Test
public void testKeyAgreement() throws Exception {
var entropy = SecureRandom.getInstance("SHA1PRNG");
entropy.setSeed(new byte[] { 6, 6, 6 });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public DemesneImpl(DemesneParameters parameters) throws GeneralSecurityException
entropy.nextBytes(pwd);
final var password = Hex.hexChars(pwd);
final Supplier<char[]> passwordProvider = () -> password;
final var keystore = KeyStore.getInstance("JKS");
final var keystore = KeyStore.getInstance("JCEKS");

keystore.load(null, password);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ public void before() throws Exception {
routers.add(localRouter);
var dbUrl = String.format("jdbc:h2:mem:sql-%s-%s;DB_CLOSE_DELAY=-1", member.getId(), UUID.randomUUID());
var pdParams = new ProcessDomain.ProcessDomainParameters(dbUrl, Duration.ofMinutes(1),
"jdbc:h2:mem:%s-state".formatted(d),
checkpointDirBase, Duration.ofMillis(10), 0.00125,
Duration.ofMinutes(1), 3, 10, 0.1);
"jdbc:h2:mem:%s-state;DB_CLOSE_DELAY=-1".formatted(
d), checkpointDirBase, Duration.ofMillis(10),
0.00125, Duration.ofMinutes(1), 3, 10, 0.1);
var domain = new ProcessDomain(group, member, pdParams, params, RuntimeParameters.newBuilder()
.setFoundation(sealed)
.setContext(context)
Expand Down

0 comments on commit 6b35ab4

Please sign in to comment.