diff --git a/junit-jupiter-engine/src/test/java/a/A.java b/junit-jupiter-engine/src/test/java/a/A.java new file mode 100644 index 000000000000..580e04bb7619 --- /dev/null +++ b/junit-jupiter-engine/src/test/java/a/A.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * All rights reserved. This program and the accompanying materials are + * made available under the terms of the Eclipse Public License v2.0 which + * accompanies this distribution and is available at + * + * https://www.eclipse.org/legal/epl-v20.html + */ + +package a; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; + +public abstract class A { + + public static final List invocations = Collections.synchronizedList(new ArrayList<>()); + + @BeforeAll + static void before() { + invocations.add("A.before()"); + } + +} diff --git a/junit-jupiter-engine/src/test/java/b/B.java b/junit-jupiter-engine/src/test/java/b/B.java new file mode 100644 index 000000000000..1fb5aa55549d --- /dev/null +++ b/junit-jupiter-engine/src/test/java/b/B.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * All rights reserved. This program and the accompanying materials are + * made available under the terms of the Eclipse Public License v2.0 which + * accompanies this distribution and is available at + * + * https://www.eclipse.org/legal/epl-v20.html + */ + +package b; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import a.A; + +public class B extends A { + + @BeforeEach + void before() { + invocations.add("B.before()"); + } + + @Test + void test() { + invocations.add("B.test()"); + } + + @AfterAll + static void checkInvocations() { + assertThat(A.invocations).containsExactly("A.before()", "B.before()", "B.test()"); + } + +} diff --git a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java index 816b7fadd104..927b97451e0a 100644 --- a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java +++ b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java @@ -1489,29 +1489,27 @@ public static Stream streamMethods(Class clazz, Predicate pre Preconditions.notNull(predicate, "Predicate must not be null"); Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null"); - // @formatter:off - return findAllMethodsInHierarchy(clazz, traversalMode).stream() - .filter(predicate) - .distinct(); - // @formatter:on + return findAllMethodsInHierarchy(clazz, predicate, traversalMode).stream().distinct(); } /** * Find all non-synthetic methods in the superclass and interface hierarchy, - * excluding Object. + * excluding Object, that match the specified {@code predicate}. */ - private static List findAllMethodsInHierarchy(Class clazz, HierarchyTraversalMode traversalMode) { + private static List findAllMethodsInHierarchy(Class clazz, Predicate predicate, + HierarchyTraversalMode traversalMode) { + Preconditions.notNull(clazz, "Class must not be null"); Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null"); // @formatter:off List localMethods = getDeclaredMethods(clazz, traversalMode).stream() - .filter(method -> !method.isSynthetic()) + .filter(predicate.and(method -> !method.isSynthetic())) .collect(toList()); - List superclassMethods = getSuperclassMethods(clazz, traversalMode).stream() + List superclassMethods = getSuperclassMethods(clazz, predicate, traversalMode).stream() .filter(method -> !isMethodShadowedByLocalMethods(method, localMethods)) .collect(toList()); - List interfaceMethods = getInterfaceMethods(clazz, traversalMode).stream() + List interfaceMethods = getInterfaceMethods(clazz, predicate, traversalMode).stream() .filter(method -> !isMethodShadowedByLocalMethods(method, localMethods)) .collect(toList()); // @formatter:on @@ -1647,16 +1645,18 @@ private static int defaultMethodSorter(Method method1, Method method2) { return comparison; } - private static List getInterfaceMethods(Class clazz, HierarchyTraversalMode traversalMode) { + private static List getInterfaceMethods(Class clazz, Predicate predicate, + HierarchyTraversalMode traversalMode) { + List allInterfaceMethods = new ArrayList<>(); for (Class ifc : clazz.getInterfaces()) { // @formatter:off List localInterfaceMethods = getMethods(ifc).stream() - .filter(m -> !isAbstract(m)) + .filter(predicate.and(method -> !isAbstract(method))) .collect(toList()); - List superinterfaceMethods = getInterfaceMethods(ifc, traversalMode).stream() + List superinterfaceMethods = getInterfaceMethods(ifc, predicate, traversalMode).stream() .filter(method -> !isMethodShadowedByLocalMethods(method, localInterfaceMethods)) .collect(toList()); // @formatter:on @@ -1706,12 +1706,14 @@ private static boolean isFieldShadowedByLocalFields(Field field, List loc return localFields.stream().anyMatch(local -> local.getName().equals(field.getName())); } - private static List getSuperclassMethods(Class clazz, HierarchyTraversalMode traversalMode) { + private static List getSuperclassMethods(Class clazz, Predicate predicate, + HierarchyTraversalMode traversalMode) { + Class superclass = clazz.getSuperclass(); if (!isSearchable(superclass)) { return Collections.emptyList(); } - return findAllMethodsInHierarchy(superclass, traversalMode); + return findAllMethodsInHierarchy(superclass, predicate, traversalMode); } private static boolean isMethodShadowedByLocalMethods(Method method, List localMethods) {