Skip to content

Commit

Permalink
Apply field predicate before searching type hierarchy
Browse files Browse the repository at this point in the history
This commit includes a fix with two simple test classes that
demonstrate the issue.

TODO:

- add "formal" tests in ReflectionUtilsTests and AnnotationUtilsTests
- add release note entries

See junit-team#3498
Closes junit-team#3532
  • Loading branch information
sbrannen committed Nov 3, 2023
1 parent 0a514b9 commit 569e2dc
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package demo.a;

import static org.assertj.core.api.Assertions.assertThat;

import java.nio.file.Path;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

public class SuperclassTempDirTests {

@TempDir
static Path tempDir;

protected static Path getStaticTempDir() {
return tempDir;
}

@Test
void superTest() {
assertThat(getStaticTempDir()).exists();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package demo.b;

import static org.assertj.core.api.Assertions.assertThat;

import java.nio.file.Path;

import demo.a.SuperclassTempDirTests;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class SubclassTempDirTests extends SuperclassTempDirTests {

@TempDir
Path tempDir;

Path getInstanceTempDir() {
return this.tempDir;
}

@Test
void subTest() {
assertThat(getInstanceTempDir()).exists();
assertThat(getStaticTempDir()).exists();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ public static List<Constructor<?>> findConstructors(Class<?> clazz, Predicate<Co
*/
public static List<Field> findFields(Class<?> clazz, Predicate<Field> predicate,
HierarchyTraversalMode traversalMode) {

return streamFields(clazz, predicate, traversalMode).collect(toUnmodifiableList());
}

Expand All @@ -1252,21 +1253,23 @@ public static Stream<Field> streamFields(Class<?> clazz, Predicate<Field> predic
Preconditions.notNull(predicate, "Predicate must not be null");
Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null");

return findAllFieldsInHierarchy(clazz, traversalMode).stream().filter(predicate);
return findAllFieldsInHierarchy(clazz, predicate, traversalMode).stream();
}

private static List<Field> findAllFieldsInHierarchy(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Field> findAllFieldsInHierarchy(Class<?> clazz, Predicate<Field> predicate,
HierarchyTraversalMode traversalMode) {

Preconditions.notNull(clazz, "Class must not be null");
Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null");

// @formatter:off
List<Field> localFields = getDeclaredFields(clazz).stream()
List<Field> localFields = getDeclaredFields(clazz, predicate).stream()
.filter(field -> !field.isSynthetic())
.collect(toList());
List<Field> superclassFields = getSuperclassFields(clazz, traversalMode).stream()
List<Field> superclassFields = getSuperclassFields(clazz, predicate, traversalMode).stream()
.filter(field -> !isFieldShadowedByLocalFields(field, localFields))
.collect(toList());
List<Field> interfaceFields = getInterfaceFields(clazz, traversalMode).stream()
List<Field> interfaceFields = getInterfaceFields(clazz, predicate, traversalMode).stream()
.filter(field -> !isFieldShadowedByLocalFields(field, localFields))
.collect(toList());
// @formatter:on
Expand Down Expand Up @@ -1529,18 +1532,20 @@ private static List<Method> findAllMethodsInHierarchy(Class<?> clazz, Predicate<

/**
* Custom alternative to {@link Class#getFields()} that sorts the fields
* and converts them to a mutable list.
* which match the supplied predicate and converts them to a mutable list.
* @param predicate the field filter; never {@code null}
*/
private static List<Field> getFields(Class<?> clazz) {
return toSortedMutableList(clazz.getFields());
private static List<Field> getFields(Class<?> clazz, Predicate<Field> predicate) {
return toSortedMutableList(clazz.getFields(), predicate);
}

/**
* Custom alternative to {@link Class#getDeclaredFields()} that sorts the
* fields and converts them to a mutable list.
* fields which match the supplied predicate and converts them to a mutable list.
* @param predicate the field filter; never {@code null}
*/
private static List<Field> getDeclaredFields(Class<?> clazz) {
return toSortedMutableList(clazz.getDeclaredFields());
private static List<Field> getDeclaredFields(Class<?> clazz, Predicate<Field> predicate) {
return toSortedMutableList(clazz.getDeclaredFields(), predicate);
}

/**
Expand Down Expand Up @@ -1602,9 +1607,10 @@ private static List<Method> getDefaultMethods(Class<?> clazz) {
// @formatter:on
}

private static List<Field> toSortedMutableList(Field[] fields) {
private static List<Field> toSortedMutableList(Field[] fields, Predicate<Field> predicate) {
// @formatter:off
return Arrays.stream(fields)
.filter(predicate)
.sorted(ReflectionUtils::defaultFieldSorter)
// Use toCollection() instead of toList() to ensure list is mutable.
.collect(toCollection(ArrayList::new));
Expand Down Expand Up @@ -1672,13 +1678,15 @@ private static List<Method> getInterfaceMethods(Class<?> clazz, Predicate<Method
return allInterfaceMethods;
}

private static List<Field> getInterfaceFields(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Field> getInterfaceFields(Class<?> clazz, Predicate<Field> predicate,
HierarchyTraversalMode traversalMode) {

List<Field> allInterfaceFields = new ArrayList<>();
for (Class<?> ifc : clazz.getInterfaces()) {
List<Field> localInterfaceFields = getFields(ifc);
List<Field> localInterfaceFields = getFields(ifc, predicate);

// @formatter:off
List<Field> superinterfaceFields = getInterfaceFields(ifc, traversalMode).stream()
List<Field> superinterfaceFields = getInterfaceFields(ifc, predicate, traversalMode).stream()
.filter(field -> !isFieldShadowedByLocalFields(field, localInterfaceFields))
.collect(toList());
// @formatter:on
Expand All @@ -1694,12 +1702,14 @@ private static List<Field> getInterfaceFields(Class<?> clazz, HierarchyTraversal
return allInterfaceFields;
}

private static List<Field> getSuperclassFields(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Field> getSuperclassFields(Class<?> clazz, Predicate<Field> predicate,
HierarchyTraversalMode traversalMode) {

Class<?> superclass = clazz.getSuperclass();
if (!isSearchable(superclass)) {
return Collections.emptyList();
}
return findAllFieldsInHierarchy(superclass, traversalMode);
return findAllFieldsInHierarchy(superclass, predicate, traversalMode);
}

private static boolean isFieldShadowedByLocalFields(Field field, List<Field> localFields) {
Expand Down

0 comments on commit 569e2dc

Please sign in to comment.