From f61c1ca522ca7ff360fa2f7fc57ca13f174342ea Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Tue, 28 Apr 2026 21:36:52 +0200 Subject: [PATCH] Java: support parameterized types in `ChangeMethodInvocationReturnType` (#7502) Previously the recipe ran `JavaType.buildType(newReturnType)`, which folded the generic suffix into a malformed `ShallowClass` FQN and rebuilt the variable's type expression as a single `J.Identifier` whose simple name still contained `<...>`. Parse the new return type into a proper `TypeTree` (handling `Parameterized`, arrays, wildcards, primitives) and walk every `FullyQualified` node when adding/removing imports so type parameters are imported too. --- .../ChangeMethodInvocationReturnType.java | 190 +++++++++++++++--- .../ChangeMethodInvocationReturnTypeTest.java | 129 ++++++++++++ 2 files changed, 296 insertions(+), 23 deletions(-) diff --git a/rewrite-java/src/main/java/org/openrewrite/java/ChangeMethodInvocationReturnType.java b/rewrite-java/src/main/java/org/openrewrite/java/ChangeMethodInvocationReturnType.java index 622e1f791ff..1a7fdb371ce 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/ChangeMethodInvocationReturnType.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/ChangeMethodInvocationReturnType.java @@ -17,14 +17,17 @@ import lombok.EqualsAndHashCode; import lombok.Value; +import org.jspecify.annotations.Nullable; import org.openrewrite.*; import org.openrewrite.internal.ListUtils; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaType; -import org.openrewrite.java.tree.TypeUtils; +import org.openrewrite.java.tree.*; import org.openrewrite.marker.Markers; +import java.util.ArrayList; +import java.util.List; + import static java.util.Collections.emptyList; +import static org.openrewrite.Tree.randomId; @Value @EqualsAndHashCode(callSuper = false) @@ -36,7 +39,8 @@ public class ChangeMethodInvocationReturnType extends Recipe { String methodPattern; @Option(displayName = "New method invocation return type", - description = "The fully qualified new return type of method invocation.", + description = "The fully qualified new return type of method invocation. " + + "Parameterized types like `java.util.Set` are supported.", example = "long") String newReturnType; @@ -61,7 +65,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu J.MethodInvocation m = super.visitMethodInvocation(method, ctx); JavaType.Method type = m.getMethodType(); if (methodMatcher.matches(method) && type != null && !newReturnType.equals(type.getReturnType().toString())) { - type = type.withReturnType(JavaType.buildType(newReturnType)); + JavaType newType = createTypeTree(newReturnType).getType(); + type = type.withReturnType(newType); m = m.withMethodType(type); if (m.getName().getType() != null) { m = m.withName(m.getName().withType(type)); @@ -74,27 +79,19 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu @Override public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) { methodUpdated = false; - JavaType.FullyQualified originalType = multiVariable.getTypeAsFullyQualified(); + JavaType originalType = multiVariable.getType(); J.VariableDeclarations mv = super.visitVariableDeclarations(multiVariable, ctx); if (methodUpdated) { - JavaType newType = JavaType.buildType(newReturnType); - JavaType.FullyQualified newFieldType = TypeUtils.asFullyQualified(newType); - - maybeRemoveImport(originalType); - maybeAddImport(newFieldType); - - mv = mv.withTypeExpression(mv.getTypeExpression() == null ? - null : - new J.Identifier(mv.getTypeExpression().getId(), - mv.getTypeExpression().getPrefix(), - Markers.EMPTY, - emptyList(), - newReturnType.substring(newReturnType.lastIndexOf('.') + 1), - newType, - null - ) - ); + TypeTree newTypeTree = createTypeTree(newReturnType); + JavaType newType = newTypeTree.getType(); + + removeImportsForType(originalType); + addImportsForType(newType); + + if (mv.getTypeExpression() != null) { + mv = mv.withTypeExpression(newTypeTree.withPrefix(mv.getTypeExpression().getPrefix())); + } mv = mv.withVariables(ListUtils.map(mv.getVariables(), var -> { JavaType.FullyQualified varType = TypeUtils.asFullyQualified(var.getType()); @@ -107,6 +104,153 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations m return mv; } + + private void addImportsForType(@Nullable JavaType type) { + if (type instanceof JavaType.Parameterized) { + JavaType.Parameterized parameterized = (JavaType.Parameterized) type; + maybeAddImport(parameterized.getType()); + for (JavaType param : parameterized.getTypeParameters()) { + addImportsForType(param); + } + } else if (type instanceof JavaType.Array) { + addImportsForType(((JavaType.Array) type).getElemType()); + } else if (type instanceof JavaType.FullyQualified) { + maybeAddImport((JavaType.FullyQualified) type); + } + } + + private void removeImportsForType(@Nullable JavaType type) { + if (type instanceof JavaType.Parameterized) { + JavaType.Parameterized parameterized = (JavaType.Parameterized) type; + maybeRemoveImport(parameterized.getType()); + for (JavaType param : parameterized.getTypeParameters()) { + removeImportsForType(param); + } + } else if (type instanceof JavaType.Array) { + removeImportsForType(((JavaType.Array) type).getElemType()); + } else if (type instanceof JavaType.FullyQualified) { + maybeRemoveImport((JavaType.FullyQualified) type); + } + } + + private TypeTree createTypeTree(String typeName) { + int arrayIndex = typeName.lastIndexOf('['); + if (arrayIndex != -1) { + TypeTree elementType = createTypeTree(typeName.substring(0, arrayIndex)); + return new J.ArrayType( + randomId(), + Space.EMPTY, + Markers.EMPTY, + elementType, + null, + JLeftPadded.build(Space.EMPTY), + new JavaType.Array(null, elementType.getType(), null) + ); + } + int genericsIndex = typeName.indexOf('<'); + if (genericsIndex != -1) { + TypeTree rawType = createTypeTree(typeName.substring(0, genericsIndex)); + List> typeParameters = new ArrayList<>(); + List typeParameterTypes = new ArrayList<>(); + List rawArgs = splitTypeArguments(typeName.substring(genericsIndex + 1, typeName.lastIndexOf('>'))); + for (int i = 0; i < rawArgs.size(); i++) { + TypeTree paramTree = createTypeTree(rawArgs.get(i).trim()); + if (i > 0) { + paramTree = paramTree.withPrefix(Space.SINGLE_SPACE); + } + typeParameters.add(JRightPadded.build((Expression) paramTree)); + typeParameterTypes.add(paramTree.getType()); + } + JavaType.FullyQualified rawFqn = TypeUtils.asFullyQualified(rawType.getType()); + return new J.ParameterizedType( + randomId(), + Space.EMPTY, + Markers.EMPTY, + rawType, + JContainer.build(Space.EMPTY, typeParameters, Markers.EMPTY), + new JavaType.Parameterized(null, rawFqn, typeParameterTypes) + ); + } + JavaType.Primitive primitive = JavaType.Primitive.fromKeyword(typeName); + if (primitive != null) { + return new J.Primitive( + randomId(), + Space.EMPTY, + Markers.EMPTY, + primitive + ); + } + if ("?".equals(typeName)) { + return new J.Wildcard( + randomId(), + Space.EMPTY, + Markers.EMPTY, + null, + null + ); + } + if (typeName.startsWith("?") && typeName.contains("extends")) { + return new J.Wildcard( + randomId(), + Space.EMPTY, + Markers.EMPTY, + new JLeftPadded<>(Space.SINGLE_SPACE, J.Wildcard.Bound.Extends, Markers.EMPTY), + createTypeTree(typeName.substring(typeName.indexOf("extends") + "extends".length() + 1).trim()).withPrefix(Space.SINGLE_SPACE) + ); + } + if (typeName.startsWith("?") && typeName.contains("super")) { + return new J.Wildcard( + randomId(), + Space.EMPTY, + Markers.EMPTY, + new JLeftPadded<>(Space.SINGLE_SPACE, J.Wildcard.Bound.Super, Markers.EMPTY), + createTypeTree(typeName.substring(typeName.indexOf("super") + "super".length() + 1).trim()).withPrefix(Space.SINGLE_SPACE) + ); + } + if (typeName.indexOf('.') == -1) { + String javaLangType = TypeUtils.findQualifiedJavaLangTypeName(typeName); + JavaType type = javaLangType != null ? + JavaType.buildType(javaLangType) : + JavaType.ShallowClass.build(typeName); + return new J.Identifier( + randomId(), + Space.EMPTY, + Markers.EMPTY, + emptyList(), + typeName, + type, + null + ); + } + return new J.Identifier( + randomId(), + Space.EMPTY, + Markers.EMPTY, + emptyList(), + typeName.substring(typeName.lastIndexOf('.') + 1), + JavaType.ShallowClass.build(typeName), + null + ); + } + + private List splitTypeArguments(String args) { + List result = new ArrayList<>(); + int depth = 0; + int start = 0; + for (int i = 0; i < args.length(); i++) { + char c = args.charAt(i); + if (c == '<') { + depth++; + } else if (c == '>') { + depth--; + } else if (c == ',' && depth == 0) { + result.add(args.substring(start, i)); + start = i + 1; + } + } + result.add(args.substring(start)); + return result; + } }; } } diff --git a/rewrite-java/src/test/java/org/openrewrite/java/ChangeMethodInvocationReturnTypeTest.java b/rewrite-java/src/test/java/org/openrewrite/java/ChangeMethodInvocationReturnTypeTest.java index a74340a33a4..ef0676260c8 100644 --- a/rewrite-java/src/test/java/org/openrewrite/java/ChangeMethodInvocationReturnTypeTest.java +++ b/rewrite-java/src/test/java/org/openrewrite/java/ChangeMethodInvocationReturnTypeTest.java @@ -121,4 +121,133 @@ void foo() { ) ); } + + @Test + void replaceVariableAssignmentWithGenericReturnType() { + rewriteRun( + spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.util.Set")) + .parser(JavaParser.fromJavaVersion() + //language=java + .dependsOn( + """ + package bar; + import java.util.List; + public class Bar { + public static List bar() { + return null; + } + } + """ + ) + ), + //language=java + java( + """ + import bar.Bar; + import java.util.List; + class Foo { + void foo() { + List one = Bar.bar(); + } + } + """, + """ + import bar.Bar; + + import java.util.Set; + + class Foo { + void foo() { + Set one = Bar.bar(); + } + } + """ + ) + ); + } + + @Test + void replaceVariableAssignmentWithNestedGenericReturnType() { + rewriteRun( + spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.util.Map>")) + .parser(JavaParser.fromJavaVersion() + //language=java + .dependsOn( + """ + package bar; + import java.util.List; + public class Bar { + public static List bar() { + return null; + } + } + """ + ) + ), + //language=java + java( + """ + import bar.Bar; + import java.util.List; + class Foo { + void foo() { + List one = Bar.bar(); + } + } + """, + """ + import bar.Bar; + import java.util.List; + import java.util.Map; + + class Foo { + void foo() { + Map> one = Bar.bar(); + } + } + """ + ) + ); + } + + @Test + void replaceParameterizedReturnTypeWithRaw() { + rewriteRun( + spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.lang.Object")) + .parser(JavaParser.fromJavaVersion() + //language=java + .dependsOn( + """ + package bar; + import java.util.List; + public class Bar { + public static List bar() { + return null; + } + } + """ + ) + ), + //language=java + java( + """ + import bar.Bar; + import java.util.List; + class Foo { + void foo() { + List one = Bar.bar(); + } + } + """, + """ + import bar.Bar; + class Foo { + void foo() { + Object one = Bar.bar(); + } + } + """ + ) + ); + } }