Skip to content

Commit

Permalink
Additional test extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
kwakeroni committed Nov 15, 2018
1 parent 8b08920 commit 60e02ac
Show file tree
Hide file tree
Showing 10 changed files with 852 additions and 0 deletions.
5 changes: 5 additions & 0 deletions parameters-core/parameters-test-support/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package be.kwakeroni.test.extension;

import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Path;

import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

public class ContextClassLoaderExtension extends ExtensionSupport implements BeforeTestExecutionCallback, AfterTestExecutionCallback {


@Override
public void beforeTestExecution(ExtensionContext context) {

URL[] urls =
getFields(context.getRequiredTestInstance())
.filter(annotatedBy(AddToClasspath.class))
.map(InstanceField::get)
.map(this::toURL)
.toArray(URL[]::new);


ClassLoader original = Thread.currentThread().getContextClassLoader();
State.ORIGINAL_CLASSLOADER.set(context, original);
Thread.currentThread().setContextClassLoader(new URLClassLoader(urls, original));
}

private URL toURL(Object o) {
try {
if (o instanceof File) {
return ((File) o).toURI().toURL();
} else if (o instanceof Path) {
return ((Path) o).toUri().toURL();
} else if (o instanceof URL) {
return (URL) o;
} else if (o instanceof String) {
return new URL((String) o);
} else {
throw new IllegalStateException("Could not convert to URL: " + o);
}
} catch (IOException exc) {
throw new UncheckedIOException(exc);
}
}

@Override
public void afterTestExecution(ExtensionContext context) {
ClassLoader original = State.ORIGINAL_CLASSLOADER.get(context, ClassLoader.class);
Thread.currentThread().setContextClassLoader(original);
}

@Target({FIELD, PARAMETER})
@Retention(RUNTIME)
public static @interface AddToClasspath {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package be.kwakeroni.test.extension;

import org.junit.jupiter.api.extension.ExtensionContext;

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ExtensionSupport {

protected Optional<Object> getParentTestInstance(ExtensionContext context) {
return context.getTestInstance()
.flatMap(this::getEncapsulatingInstance);
}

private Optional<Object> getEncapsulatingInstance(Object instance) {
Class<?> encapsClass = instance.getClass().getEnclosingClass();
if (encapsClass == null) {
return Optional.empty();
}
List<Field> fields = classFields(instance.getClass())
.filter(field -> field.getType() == encapsClass)
.collect(Collectors.toList());

if (fields.size() == 1) {
return Optional.of(getField(fields.get(0), instance));
}
if (fields.size() > 1) {
System.out.println("[WARN] Encountered multiple fields of encapsulating class: " + fields);
}

return Optional.empty();
}

protected Stream<InstanceField> getFields(Object testInstance) {
return recursiveStream(testInstance, this::getEncapsulatingInstance)
.flatMap(this::instanceFields);
}

private Stream<InstanceField> instanceFields(Object testInstance) {
return recursiveStream(testInstance.getClass(), this::getSuperClassBelowObject)
.flatMap(this::classFields)
.map(field -> new InstanceField(field, testInstance));
}

private Stream<Field> classFields(Class<?> type) {
return Arrays.stream(type.getDeclaredFields());
}

private Optional<Class<?>> getSuperClassBelowObject(Class<?> clazz) {
return Optional.ofNullable(clazz.getSuperclass())
.filter((Object o) -> o != Object.class)
.map(Function.identity());

}


private <T> Stream<T> recursiveStream(T seed, Function<T, Optional<T>> function) {
return recursiveStream(seed, Stream::of, function);
}


private <S, T> Stream<T> recursiveStream(S seed, Function<S, Stream<T>> streamSupplier, Function<S, Optional<S>> function) {
return Stream.of(seed)
.flatMap(s -> concatWithParents(s, streamSupplier, function));
}

private <S, T> Stream<T> concatWithParents(S seed, Function<S, Stream<T>> streamSupplier, Function<S, Optional<S>> function) {
return Stream.concat(
streamSupplier.apply(seed),
function.apply(seed)
.map(parent -> recursiveStream(parent, streamSupplier, function))
.orElseGet(Stream::empty));
}

protected static void setField(Field field, Object instance, Object value) {
try {
field.set(instance, value);
} catch (IllegalAccessException e) {
throw new IllegalStateException(e);
} catch (Exception exc) {
exc.printStackTrace();
}
}

private static Object getField(Field field, Object instance) {
try {
field.setAccessible(true);
return field.get(instance);
} catch (IllegalAccessException e) {
throw new IllegalStateException(e);
}
}

protected static Predicate<AnnotatedElement> annotatedBy(Class<? extends Annotation> annotationType) {
return element -> element.isAnnotationPresent(annotationType);
}

protected static Consumer<Field> inject(Object instance, Function<Class<?>, Object> value) {
return field -> setField(field, instance, value.apply(field.getType()));
}

protected static Consumer<Field> accept(Object instance, Consumer<Object> valueConsumer) {
return field -> valueConsumer.accept(getField(field, instance));
}


static class InstanceField implements AnnotatedElement {
final Field field;
final Object instance;

private InstanceField(Field field, Object instance) {
this.field = field;
this.instance = instance;
}

void set(Object value) {
setField(field, instance, value);
}

Object get() {
return getField(field, instance);
}

Class<?> getType() {
return field.getType();
}

String getName() {
return field.getName();
}


@Override
public boolean isAnnotationPresent(Class<? extends Annotation> annotationClass) {
return field.isAnnotationPresent(annotationClass);
}

@Override
public Annotation[] getAnnotations() {
return field.getAnnotations();
}

@Override
public <T extends Annotation> T getDeclaredAnnotation(Class<T> annotationClass) {
return field.getDeclaredAnnotation(annotationClass);
}

@Override
public <T extends Annotation> T[] getDeclaredAnnotationsByType(Class<T> annotationClass) {
return field.getDeclaredAnnotationsByType(annotationClass);
}

@Override
public <T extends Annotation> T getAnnotation(Class<T> annotationClass) {
return field.getAnnotation(annotationClass);
}

@Override
public Annotation[] getDeclaredAnnotations() {
return field.getDeclaredAnnotations();
}

@Override
public <T extends Annotation> T[] getAnnotationsByType(Class<T> annotationClass) {
return field.getAnnotationsByType(annotationClass);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package be.kwakeroni.test.extension;

import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestInstancePostProcessor;

import java.util.Optional;

public class LifeCycleLogger implements Extension,
TestInstancePostProcessor,
BeforeAllCallback, BeforeEachCallback, BeforeTestExecutionCallback,
AfterTestExecutionCallback, AfterEachCallback, AfterAllCallback {

public static void log(String msg) {
System.out.println("[" + Thread.currentThread().getId() + "] " + msg);
}

@Override
public void afterAll(ExtensionContext context) {
log("After All " + context.getTestInstance());
}

@Override
public void afterEach(ExtensionContext context) {
log("After Each " + context.getTestInstance());
}

@Override
public void afterTestExecution(ExtensionContext context) {
log("After Exec " + context.getTestInstance());
}

@Override
public void beforeAll(ExtensionContext context) {
log("Before All " + context.getTestInstance());
}

@Override
public void beforeEach(ExtensionContext context) {
log("Before Each " + context.getTestInstance());
}

@Override
public void beforeTestExecution(ExtensionContext context) {
log("Before Exec " + context.getTestInstance());
}

@Override
public void postProcessTestInstance(Object testInstance, ExtensionContext context) {
log("Post Process " + Optional.ofNullable(testInstance));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package be.kwakeroni.test.extension;

import org.junit.jupiter.api.extension.ExtensionContext;

import java.util.function.Supplier;

enum State {
ORIGINAL_CLASSLOADER;

private static final ExtensionContext.Namespace NAMESPACE = ExtensionContext.Namespace.create("be.kwakeroni", "test");

public void set(ExtensionContext extensionContext, Object value) {
if (extensionContext.getStore(NAMESPACE).get(name()) != null) {
throw new IllegalStateException("Store not empty");
}
extensionContext.getStore(NAMESPACE).put(name(), value);
}

public <T> T get(ExtensionContext extensionContext) {
return (T) extensionContext.getStore(NAMESPACE).get(name());

}

public <T> T get(ExtensionContext extensionContext, Class<T> type) {
return extensionContext.getStore(NAMESPACE).get(name(), type);
}

public <T> T getOrCreate(ExtensionContext extensionContext, Supplier<T> supplier) {
return (T) extensionContext.getStore(NAMESPACE).getOrComputeIfAbsent(name(), k -> supplier.get());
}


}
Loading

0 comments on commit 60e02ac

Please sign in to comment.