Skip to content

Commit

Permalink
Handle differing APIs across Java versions
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Nov 5, 2024
1 parent d837e5c commit 101e90d
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 59 deletions.
59 changes: 27 additions & 32 deletions java-client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -367,40 +367,35 @@ publishing {
}

if (runtimeJavaVersion >= JavaVersion.VERSION_11) {
val java11: SourceSet = sourceSets.create("java11") {
java {
compileClasspath += sourceSets.main.get().output + sourceSets.test.get().output
runtimeClasspath += sourceSets.main.get().output + sourceSets.test.get().output
srcDir("src/test/java11")
val java11: SourceSet = sourceSets.create("java11") {
java {
compileClasspath += sourceSets.main.get().output + sourceSets.test.get().output
runtimeClasspath += sourceSets.main.get().output + sourceSets.test.get().output
srcDir("src/test/java11")
}
}

configurations[java11.implementationConfigurationName].extendsFrom(configurations.testImplementation.get())
configurations[java11.runtimeOnlyConfigurationName].extendsFrom(configurations.testRuntimeOnly.get())

dependencies {
testImplementation("org.opensearch.test", "framework", opensearchVersion) {
exclude(group = "org.hamcrest")
}
}
}

configurations[java11.implementationConfigurationName].extendsFrom(configurations.testImplementation.get())
configurations[java11.runtimeOnlyConfigurationName].extendsFrom(configurations.testRuntimeOnly.get())
tasks.named<JavaCompile>("compileJava11Java") {
targetCompatibility = JavaVersion.VERSION_11.toString()
sourceCompatibility = JavaVersion.VERSION_11.toString()
}

tasks.named<JavaCompile>("compileTestJava") {
targetCompatibility = JavaVersion.VERSION_11.toString()
sourceCompatibility = JavaVersion.VERSION_11.toString()
}

dependencies {
testImplementation("org.opensearch.test", "framework", opensearchVersion) {
exclude(group = "org.hamcrest")
tasks.withType<Test> {
testClassesDirs += java11.output.classesDirs
classpath = sourceSets["java11"].runtimeClasspath
}
}

tasks.named<JavaCompile>("compileJava11Java") {
targetCompatibility = JavaVersion.VERSION_11.toString()
sourceCompatibility = JavaVersion.VERSION_11.toString()
}

tasks.named<JavaCompile>("compileTestJava") {
targetCompatibility = JavaVersion.VERSION_11.toString()
sourceCompatibility = JavaVersion.VERSION_11.toString()
}

tasks.named<Test>("integrationTest") {
testClassesDirs += java11.output.classesDirs
classpath = sourceSets["java11"].runtimeClasspath
}

tasks.named<Test>("unitTest") {
testClassesDirs += java11.output.classesDirs
classpath = sourceSets["java11"].runtimeClasspath
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,21 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.AttributeMap;

import javax.net.ssl.SSLContext;

@RunWith(Parameterized.class)
public class AwsSdk2TransportTests {
private static final Region TEST_REGION = Region.AP_SOUTHEAST_2;
private static final String TEST_INDEX = "sample-index1";
private static final SSLContext SSL_CONTEXT;

static {
try {
SSL_CONTEXT = GeneratedCertificateSSLContext.generate();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private HttpAsyncServer server;
private FunnellingHttpsProxy proxy;
Expand Down Expand Up @@ -121,7 +132,7 @@ public void setup() throws Exception {
.resolveAuthority(RequestRouter.LOCAL_AUTHORITY_RESOLVER)
.build()
)
.setTlsStrategy(new BasicClientTlsStrategy(GeneratedCertificateSSLContext.generate()))
.setTlsStrategy(new BasicClientTlsStrategy(SSL_CONTEXT))
.create();
server.start();
var serverAddress = (InetSocketAddress) server.listen(new InetSocketAddress(0), URIScheme.HTTPS).get().getAddress();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,30 @@
package org.opensearch.client.transport.util;

import java.io.IOException;
import java.security.InvalidKeyException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.security.KeyManagementException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.SignatureException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.time.ZonedDateTime;
import java.util.Date;
import java.util.Random;
import java.util.Vector;
import java.util.function.BiConsumer;
import java.util.function.Function;
import javax.net.ssl.SSLContext;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.opensearch.client.util.TriConsumer;
import org.opensearch.client.util.TriFunction;
import sun.security.util.KnownOIDs;
import sun.security.util.ObjectIdentifier;
import sun.security.x509.AlgorithmId;
Expand All @@ -42,6 +46,7 @@
import sun.security.x509.CertificateX509Key;
import sun.security.x509.DNSName;
import sun.security.x509.ExtendedKeyUsageExtension;
import sun.security.x509.Extension;
import sun.security.x509.GeneralName;
import sun.security.x509.GeneralNames;
import sun.security.x509.IPAddressName;
Expand All @@ -59,12 +64,12 @@ public class GeneratedCertificateSSLContext {
private static final String SIGNATURE_ALGORITHM = "SHA256with" + KEY_ALGORITHM;
private static final char[] KEYSTORE_PASSWORD = "password".toCharArray();

public static SSLContext generate() throws NoSuchAlgorithmException, IOException, CertificateException, SignatureException,
InvalidKeyException, NoSuchProviderException, KeyStoreException, KeyManagementException, UnrecoverableKeyException {
public static SSLContext generate() throws NoSuchAlgorithmException, IOException, CertificateException, KeyStoreException,
KeyManagementException, UnrecoverableKeyException {
var caKey = generateKeyPair();
var caSubject = "DC=localhost, O=localhost, OU=localhost Root CA, CN=localhost Root CA";
var caInfo = makeX509CertInfo(caSubject, caKey.getPublic(), caSubject, makeCertificateAuthorityExtensions(caKey.getPublic()));
var caCert = X509CertImpl.newSigned(caInfo, caKey.getPrivate(), SIGNATURE_ALGORITHM);
var caCert = newSigned.apply(caInfo, caKey.getPrivate(), SIGNATURE_ALGORITHM);

var hostKey = generateKeyPair();
var hostSubject = "DC=localhost, O=localhost, OU=localhost, CN=localhost";
Expand All @@ -74,7 +79,7 @@ public static SSLContext generate() throws NoSuchAlgorithmException, IOException
caSubject,
makeHostCertificateExtensions(hostKey.getPublic(), caCert)
);
var hostCert = X509CertImpl.newSigned(hostInfo, caKey.getPrivate(), SIGNATURE_ALGORITHM);
var hostCert = newSigned.apply(hostInfo, caKey.getPrivate(), SIGNATURE_ALGORITHM);

var keyStore = KeyStore.getInstance("JKS");
keyStore.load(null, KEYSTORE_PASSWORD);
Expand All @@ -90,24 +95,25 @@ private static CertificateExtensions makeCertificateAuthorityExtensions(PublicKe
var extensions = new CertificateExtensions();

var keyId = new KeyIdentifier(publicKey);
extensions.setExtension(AuthorityKeyIdentifierExtension.NAME, new AuthorityKeyIdentifierExtension(keyId, null, null));
extensions.setExtension(SubjectKeyIdentifierExtension.NAME, new SubjectKeyIdentifierExtension(keyId.getIdentifier()));
set(extensions, AuthorityKeyIdentifierExtension.NAME, new AuthorityKeyIdentifierExtension(keyId, null, null));
set(extensions, SubjectKeyIdentifierExtension.NAME, new SubjectKeyIdentifierExtension(keyId.getIdentifier()));

extensions.setExtension(BasicConstraintsExtension.NAME, new BasicConstraintsExtension(true, true, Integer.MAX_VALUE));
set(extensions, BasicConstraintsExtension.NAME, new BasicConstraintsExtension(true, true, Integer.MAX_VALUE));

var keyUsage = new KeyUsageExtension();
keyUsage.set(KeyUsageExtension.DIGITAL_SIGNATURE, true);
keyUsage.set(KeyUsageExtension.KEY_CERTSIGN, true);
keyUsage.set(KeyUsageExtension.CRL_SIGN, true);
extensions.setExtension(KeyUsageExtension.NAME, keyUsage);
set(extensions, KeyUsageExtension.NAME, keyUsage);

return extensions;
}

private static CertificateExtensions makeHostCertificateExtensions(PublicKey publicKey, X509Certificate caCert) throws IOException {
var extensions = new CertificateExtensions();

extensions.setExtension(
set(
extensions,
AuthorityKeyIdentifierExtension.NAME,
new AuthorityKeyIdentifierExtension(
new KeyIdentifier(caCert.getPublicKey()),
Expand All @@ -117,22 +123,23 @@ private static CertificateExtensions makeHostCertificateExtensions(PublicKey pub
);

var keyId = new KeyIdentifier(publicKey);
extensions.setExtension(SubjectKeyIdentifierExtension.NAME, new SubjectKeyIdentifierExtension(keyId.getIdentifier()));
set(extensions, SubjectKeyIdentifierExtension.NAME, new SubjectKeyIdentifierExtension(keyId.getIdentifier()));

extensions.setExtension(BasicConstraintsExtension.NAME, new BasicConstraintsExtension(false, Integer.MAX_VALUE));
set(extensions, BasicConstraintsExtension.NAME, new BasicConstraintsExtension(false, Integer.MAX_VALUE));

var keyUsage = new KeyUsageExtension();
keyUsage.set(KeyUsageExtension.DIGITAL_SIGNATURE, true);
keyUsage.set(KeyUsageExtension.NON_REPUDIATION, true);
keyUsage.set(KeyUsageExtension.KEY_ENCIPHERMENT, true);
extensions.setExtension(KeyUsageExtension.NAME, keyUsage);
set(extensions, KeyUsageExtension.NAME, keyUsage);

var extendedKeyUsage = new Vector<ObjectIdentifier>();
extendedKeyUsage.add(ObjectIdentifier.of(KnownOIDs.clientAuth));
extendedKeyUsage.add(ObjectIdentifier.of(KnownOIDs.serverAuth));
extensions.setExtension(ExtendedKeyUsageExtension.NAME, new ExtendedKeyUsageExtension(true, extendedKeyUsage));
set(extensions, ExtendedKeyUsageExtension.NAME, new ExtendedKeyUsageExtension(true, extendedKeyUsage));

extensions.setExtension(
set(
extensions,
SubjectAlternativeNameExtension.NAME,
new SubjectAlternativeNameExtension(
new GeneralNames().add(new GeneralName(new DNSName("localhost"))).add(new GeneralName(new IPAddressName("127.0.0.1")))
Expand All @@ -147,19 +154,19 @@ private static X509CertInfo makeX509CertInfo(
PublicKey publicKey,
String issuer,
CertificateExtensions certificateExtensions
) throws IOException, NoSuchAlgorithmException, CertificateException {
) throws IOException, NoSuchAlgorithmException {
var start = ZonedDateTime.now().minusDays(1);
var end = start.plusDays(7);

var info = new X509CertInfo();
info.setVersion(new CertificateVersion(CertificateVersion.V3));
info.setSerialNumber(new CertificateSerialNumber(new Random().nextInt() & 0x7fffffff));
info.setAlgorithmId(new CertificateAlgorithmId(AlgorithmId.get(SIGNATURE_ALGORITHM)));
info.setSubject(new X500Name(subject));
info.setKey(new CertificateX509Key(publicKey));
info.setValidity(new CertificateValidity(Date.from(start.toInstant()), Date.from(end.toInstant())));
info.setIssuer(new X500Name(issuer));
info.setExtensions(certificateExtensions);
setVersion.accept(info, new CertificateVersion(CertificateVersion.V3));
setSerialNumber.accept(info, new CertificateSerialNumber(new Random().nextInt() & 0x7fffffff));
setAlgorithmId.accept(info, new CertificateAlgorithmId(AlgorithmId.get(SIGNATURE_ALGORITHM)));
setSubject.accept(info, new X500Name(subject));
setKey.accept(info, new CertificateX509Key(publicKey));
setValidity.accept(info, new CertificateValidity(Date.from(start.toInstant()), Date.from(end.toInstant())));
setIssuer.accept(info, new X500Name(issuer));
setExtensions.accept(info, certificateExtensions);

return info;
}
Expand All @@ -169,4 +176,119 @@ private static KeyPair generateKeyPair() throws NoSuchAlgorithmException {
keyGen.initialize(2048, new SecureRandom());
return keyGen.generateKeyPair();
}

private static void set(CertificateExtensions extensions, String name, Extension value) {
setExtension.accept(extensions, name, value);
}

private static final TriConsumer<CertificateExtensions, String, Extension> setExtension;
private static final BiConsumer<X509CertInfo, CertificateVersion> setVersion;
private static final BiConsumer<X509CertInfo, CertificateSerialNumber> setSerialNumber;
private static final BiConsumer<X509CertInfo, CertificateAlgorithmId> setAlgorithmId;
private static final BiConsumer<X509CertInfo, X500Name> setSubject;
private static final BiConsumer<X509CertInfo, CertificateX509Key> setKey;
private static final BiConsumer<X509CertInfo, CertificateValidity> setValidity;
private static final BiConsumer<X509CertInfo, X500Name> setIssuer;
private static final BiConsumer<X509CertInfo, CertificateExtensions> setExtensions;
private static final TriFunction<X509CertInfo, PrivateKey, String, X509CertImpl> newSigned;

static {
try {
if (Runtime.version().compareTo(Runtime.Version.parse("20")) >= 0) {
setExtension = findVoidMethod(CertificateExtensions.class, "setExtension", String.class, Extension.class);
setVersion = findVoidMethod(X509CertInfo.class, "setVersion", CertificateVersion.class);
setSerialNumber = findVoidMethod(X509CertInfo.class, "setSerialNumber", CertificateSerialNumber.class);
setAlgorithmId = findVoidMethod(X509CertInfo.class, "setAlgorithmId", CertificateAlgorithmId.class);
setSubject = findVoidMethod(X509CertInfo.class, "setSubject", X500Name.class);
setKey = findVoidMethod(X509CertInfo.class, "setKey", CertificateX509Key.class);
setValidity = findVoidMethod(X509CertInfo.class, "setValidity", CertificateValidity.class);
setIssuer = findVoidMethod(X509CertInfo.class, "setIssuer", X500Name.class);
setExtensions = findVoidMethod(X509CertInfo.class, "setExtensions", CertificateExtensions.class);
newSigned = findStaticMethod(
X509CertImpl.class,
"newSigned",
X509CertImpl.class,
X509CertInfo.class,
PrivateKey.class,
String.class
);
} else {
setExtension = findVoidMethod(CertificateExtensions.class, "set", String.class, Object.class)::accept;
var setCertInfo = findVoidMethod(X509CertInfo.class, "set", String.class, Object.class);
setVersion = (info, version) -> setCertInfo.accept(info, X509CertInfo.VERSION, version);
setSerialNumber = (info, serialNumber) -> setCertInfo.accept(info, X509CertInfo.SERIAL_NUMBER, serialNumber);
setAlgorithmId = (info, algorithmId) -> setCertInfo.accept(info, X509CertInfo.ALGORITHM_ID, algorithmId);
setSubject = (info, subject) -> setCertInfo.accept(info, X509CertInfo.SUBJECT, subject);
setKey = (info, key) -> setCertInfo.accept(info, X509CertInfo.KEY, key);
setValidity = (info, validity) -> setCertInfo.accept(info, X509CertInfo.VALIDITY, validity);
setIssuer = (info, issuer) -> setCertInfo.accept(info, X509CertInfo.ISSUER, issuer);
setExtensions = (info, extensions) -> setCertInfo.accept(info, X509CertInfo.EXTENSIONS, extensions);

var x509CertImplCtor = findCtor(X509CertImpl.class, X509CertInfo.class);
var sign = findVoidMethod(X509CertImpl.class, "sign", PrivateKey.class, String.class);
newSigned = (info, privateKey, algorithm) -> {
var cert = x509CertImplCtor.apply(info);
sign.accept(cert, privateKey, algorithm);
return cert;
};
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

private static <This, A> Function<A, This> findCtor(Class<This> clazz, Class<A> aType) throws Throwable {
var handle = MethodHandles.lookup().findConstructor(clazz, MethodType.methodType(void.class, aType));
return (a) -> {
try {
//noinspection unchecked
return (This) handle.invoke(a);
} catch (Throwable e) {
throw new RuntimeException(e);
}
};
}

private static <Return, A, B, C> TriFunction<A, B, C, Return> findStaticMethod(
Class<?> clazz,
String name,
Class<Return> rType,
Class<A> aType,
Class<B> bType,
Class<C> cType
) throws Throwable {
var handle = MethodHandles.lookup().findStatic(clazz, name, MethodType.methodType(rType, aType, bType, cType));
return (a, b, c) -> {
try {
//noinspection unchecked
return (Return) handle.invoke(a, b, c);
} catch (Throwable e) {
throw new RuntimeException(e);
}
};
}

private static <This, A> BiConsumer<This, A> findVoidMethod(Class<This> clazz, String name, Class<A> aType) throws Throwable {
var handle = MethodHandles.lookup().findVirtual(clazz, name, MethodType.methodType(void.class, aType));
// noinspection unchecked
return (thiz, a) -> {
try {
handle.bindTo(thiz).invoke(a);
} catch (Throwable e) {
throw new RuntimeException(e);
}
};
}

private static <This, A, B> TriConsumer<This, A, B> findVoidMethod(Class<This> clazz, String name, Class<A> aType, Class<B> bType)
throws Throwable {
var handle = MethodHandles.lookup().findVirtual(clazz, name, MethodType.methodType(void.class, aType, bType));
return (thiz, a, b) -> {
try {
handle.bindTo(thiz).invoke(a, b);
} catch (Throwable e) {
throw new RuntimeException(e);
}
};
}
}

0 comments on commit 101e90d

Please sign in to comment.