Skip to content

Commit

Permalink
Improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
rickie authored and Stephan202 committed Oct 30, 2022
1 parent 2aa9b26 commit 329e315
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker;
import com.google.errorprone.bugpatterns.BugChecker.MethodTreeMatcher;
import com.google.errorprone.fixes.Fix;
import com.google.errorprone.fixes.SuggestedFix;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.matchers.Matcher;
Expand All @@ -35,17 +34,21 @@
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.Tree.Kind;
import com.sun.tools.javac.code.Type;
import java.util.Optional;
import java.util.stream.Stream;
import javax.lang.model.type.TypeKind;
import tech.picnic.errorprone.bugpatterns.util.SourceCode;

/**
* A {@link BugChecker} that flags JUnit tests with {@link
* org.junit.jupiter.params.provider.MethodSource} that can be written as a {@link
* org.junit.jupiter.params.provider.ValueSource}.
*/
// XXX: Support rewriting when there are multiple sources defined for the `@MethodSource`, iff
// applicable.
// XXX: Don't remove factory methods that are used by another `@MethodSource`.
@AutoService(BugChecker.class)
@BugPattern(
summary =
Expand Down Expand Up @@ -77,15 +80,20 @@ public Description matchMethod(MethodTree tree, VisitorState state) {
AnnotationTree annotationTree =
ASTHelpers.getAnnotationWithSimpleName(
tree.getModifiers().getAnnotations(), "MethodSource");
Optional<Fix> fix = tryConstructValueSourceFix(parameterType, annotationTree, state);

return fix.isPresent() ? describeMatch(tree, fix.orElseThrow()) : Description.NO_MATCH;
return tryConstructValueSourceFix(parameterType, annotationTree, state)
.map(fix -> describeMatch(tree, fix.build()))
.orElse(Description.NO_MATCH);
}

private static Optional<Fix> tryConstructValueSourceFix(
private static Optional<SuggestedFix.Builder> tryConstructValueSourceFix(
Type parameterType, AnnotationTree methodSourceAnnotation, VisitorState state) {
String factoryMethodName = extractFactoryMethodName(methodSourceAnnotation);
MethodTree factoryMethod = getFactoryMethod(factoryMethodName, state);
Optional<String> factoryMethodName = extractSingleFactoryMethodName(methodSourceAnnotation);
if (factoryMethodName.isEmpty()) {
/* `@MethodSource` defines more than one source. */
return Optional.empty();
}
MethodTree factoryMethod = findFactoryMethod(factoryMethodName.orElseThrow(), state);

Optional<String> valueSourceAttributeValue =
getReturnTree(factoryMethod)
Expand All @@ -104,8 +112,7 @@ private static Optional<Fix> tryConstructValueSourceFix(
String.format(
"@ValueSource(%s = {%s})",
toValueSourceAttributeName(parameterType.toString()), attributeValue))
.delete(factoryMethod)
.build());
.delete(factoryMethod));
}

private static Optional<ReturnTree> getReturnTree(MethodTree methodTree) {
Expand All @@ -115,14 +122,18 @@ private static Optional<ReturnTree> getReturnTree(MethodTree methodTree) {
.map(ReturnTree.class::cast);
}

private static String extractFactoryMethodName(AnnotationTree methodSourceAnnotation) {
ExpressionTree expression =
private static Optional<String> extractSingleFactoryMethodName(
AnnotationTree methodSourceAnnotation) {
ExpressionTree attributeExpression =
((AssignmentTree) Iterables.getOnlyElement(methodSourceAnnotation.getArguments()))
.getExpression();
return ASTHelpers.getType(expression).stringValue();
Type attributeType = ASTHelpers.getType(attributeExpression);
return attributeType.getKind() == TypeKind.ARRAY
? Optional.empty()
: Optional.of(attributeType.stringValue());
}

private static MethodTree getFactoryMethod(String factoryMethodName, VisitorState state) {
private static MethodTree findFactoryMethod(String factoryMethodName, VisitorState state) {
return state.findEnclosing(ClassTree.class).getMembers().stream()
.filter(MethodTree.class::isInstance)
.map(MethodTree.class::cast)
Expand All @@ -133,7 +144,7 @@ private static MethodTree getFactoryMethod(String factoryMethodName, VisitorStat

private static Optional<String> extractArgumentsFromStream(
MethodInvocationTree tree, VisitorState state) {
ImmutableList<String> args =
ImmutableList<String> arguments =
tree.getArguments().stream()
.filter(MethodInvocationTree.class::isInstance)
.map(MethodInvocationTree.class::cast)
Expand All @@ -143,10 +154,10 @@ private static Optional<String> extractArgumentsFromStream(
.collect(toImmutableList());

/* Not all values are compile-time constants. */
if (args.size() != tree.getArguments().size()) {
if (arguments.size() != tree.getArguments().size()) {
return Optional.empty();
}
return Optional.of(String.join(", ", args));
return Optional.of(String.join(", ", arguments));
}

private static String toValueSourceAttributeName(String type) {
Expand All @@ -163,7 +174,7 @@ private static String toValueSourceAttributeName(String type) {
}

private static boolean isCompileTimeConstant(ExpressionTree argument) {
return argument.getKind() == Tree.Kind.MEMBER_SELECT
return argument.getKind() == Kind.MEMBER_SELECT
? ((MemberSelectTree) argument).getIdentifier().contentEquals("class")
: ASTHelpers.constValue(argument) != null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ final class JUnitValueSourceTest {
private final BugCheckerRefactoringTestHelper refactoringTestHelper =
BugCheckerRefactoringTestHelper.newInstance(JUnitValueSource.class, getClass());

// XXX: Add a test case for when a factory is used by more than one test.

@Test
void identificationChar() {
compilationTestHelper
Expand Down Expand Up @@ -182,7 +180,38 @@ void identificationNoRuntimeParameters() {
}

@Test
void identificationConstantValues() {
void identificationDontFlagForMultipleFactories() {
compilationTestHelper
.addSourceLines(
"A.java",
"import static org.assertj.core.api.Assertions.assertThat;",
"import static org.junit.jupiter.params.provider.Arguments.arguments;",
"",
"import java.util.stream.Stream;",
"import org.junit.jupiter.params.ParameterizedTest;",
"import org.junit.jupiter.params.provider.Arguments;",
"import org.junit.jupiter.params.provider.MethodSource;",
"",
"class A {",
" @ParameterizedTest",
" @MethodSource({\"fooTestCases\", \"barTestCases\"})",
" void foo(int i) {",
" assertThat(i).isNotNull();",
" }",
"",
" private static Stream<Arguments> fooTestCases() {",
" return Stream.of(arguments(1));",
" }",
"",
" private static Stream<Arguments> barTestCases() {",
" return Stream.of(arguments(1));",
" }",
"}")
.doTest();
}

@Test
void identificationOnlyConstantValues() {
compilationTestHelper
.addSourceLines(
"A.java",
Expand Down

0 comments on commit 329e315

Please sign in to comment.