Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

import lombok.EqualsAndHashCode;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.*;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.TypeUtils;
import org.openrewrite.java.tree.*;
import org.openrewrite.marker.Markers;

import java.util.ArrayList;
import java.util.List;

import static java.util.Collections.emptyList;
import static org.openrewrite.Tree.randomId;

@Value
@EqualsAndHashCode(callSuper = false)
Expand All @@ -36,7 +39,8 @@ public class ChangeMethodInvocationReturnType extends Recipe {
String methodPattern;

@Option(displayName = "New method invocation return type",
description = "The fully qualified new return type of method invocation.",
description = "The fully qualified new return type of method invocation. " +
"Parameterized types like `java.util.Set<java.lang.String>` are supported.",
example = "long")
String newReturnType;

Expand All @@ -61,7 +65,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
J.MethodInvocation m = super.visitMethodInvocation(method, ctx);
JavaType.Method type = m.getMethodType();
if (methodMatcher.matches(method) && type != null && !newReturnType.equals(type.getReturnType().toString())) {
type = type.withReturnType(JavaType.buildType(newReturnType));
JavaType newType = createTypeTree(newReturnType).getType();
type = type.withReturnType(newType);
m = m.withMethodType(type);
if (m.getName().getType() != null) {
m = m.withName(m.getName().withType(type));
Expand All @@ -74,27 +79,19 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
@Override
public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext ctx) {
methodUpdated = false;
JavaType.FullyQualified originalType = multiVariable.getTypeAsFullyQualified();
JavaType originalType = multiVariable.getType();
J.VariableDeclarations mv = super.visitVariableDeclarations(multiVariable, ctx);

if (methodUpdated) {
JavaType newType = JavaType.buildType(newReturnType);
JavaType.FullyQualified newFieldType = TypeUtils.asFullyQualified(newType);

maybeRemoveImport(originalType);
maybeAddImport(newFieldType);

mv = mv.withTypeExpression(mv.getTypeExpression() == null ?
null :
new J.Identifier(mv.getTypeExpression().getId(),
mv.getTypeExpression().getPrefix(),
Markers.EMPTY,
emptyList(),
newReturnType.substring(newReturnType.lastIndexOf('.') + 1),
newType,
null
)
);
TypeTree newTypeTree = createTypeTree(newReturnType);
JavaType newType = newTypeTree.getType();

removeImportsForType(originalType);
addImportsForType(newType);

if (mv.getTypeExpression() != null) {
mv = mv.withTypeExpression(newTypeTree.withPrefix(mv.getTypeExpression().getPrefix()));
}

mv = mv.withVariables(ListUtils.map(mv.getVariables(), var -> {
JavaType.FullyQualified varType = TypeUtils.asFullyQualified(var.getType());
Expand All @@ -107,6 +104,153 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations m

return mv;
}

private void addImportsForType(@Nullable JavaType type) {
if (type instanceof JavaType.Parameterized) {
JavaType.Parameterized parameterized = (JavaType.Parameterized) type;
maybeAddImport(parameterized.getType());
for (JavaType param : parameterized.getTypeParameters()) {
addImportsForType(param);
}
} else if (type instanceof JavaType.Array) {
addImportsForType(((JavaType.Array) type).getElemType());
} else if (type instanceof JavaType.FullyQualified) {
maybeAddImport((JavaType.FullyQualified) type);
}
}

private void removeImportsForType(@Nullable JavaType type) {
if (type instanceof JavaType.Parameterized) {
JavaType.Parameterized parameterized = (JavaType.Parameterized) type;
maybeRemoveImport(parameterized.getType());
for (JavaType param : parameterized.getTypeParameters()) {
removeImportsForType(param);
}
} else if (type instanceof JavaType.Array) {
removeImportsForType(((JavaType.Array) type).getElemType());
} else if (type instanceof JavaType.FullyQualified) {
maybeRemoveImport((JavaType.FullyQualified) type);
}
}

private TypeTree createTypeTree(String typeName) {
int arrayIndex = typeName.lastIndexOf('[');
if (arrayIndex != -1) {
TypeTree elementType = createTypeTree(typeName.substring(0, arrayIndex));
return new J.ArrayType(
randomId(),
Space.EMPTY,
Markers.EMPTY,
elementType,
null,
JLeftPadded.build(Space.EMPTY),
new JavaType.Array(null, elementType.getType(), null)
);
}
int genericsIndex = typeName.indexOf('<');
if (genericsIndex != -1) {
TypeTree rawType = createTypeTree(typeName.substring(0, genericsIndex));
List<JRightPadded<Expression>> typeParameters = new ArrayList<>();
List<JavaType> typeParameterTypes = new ArrayList<>();
List<String> rawArgs = splitTypeArguments(typeName.substring(genericsIndex + 1, typeName.lastIndexOf('>')));
for (int i = 0; i < rawArgs.size(); i++) {
TypeTree paramTree = createTypeTree(rawArgs.get(i).trim());
if (i > 0) {
paramTree = paramTree.withPrefix(Space.SINGLE_SPACE);
}
typeParameters.add(JRightPadded.build((Expression) paramTree));
typeParameterTypes.add(paramTree.getType());
}
JavaType.FullyQualified rawFqn = TypeUtils.asFullyQualified(rawType.getType());
return new J.ParameterizedType(
randomId(),
Space.EMPTY,
Markers.EMPTY,
rawType,
JContainer.build(Space.EMPTY, typeParameters, Markers.EMPTY),
new JavaType.Parameterized(null, rawFqn, typeParameterTypes)
);
}
JavaType.Primitive primitive = JavaType.Primitive.fromKeyword(typeName);
if (primitive != null) {
return new J.Primitive(
randomId(),
Space.EMPTY,
Markers.EMPTY,
primitive
);
}
if ("?".equals(typeName)) {
return new J.Wildcard(
randomId(),
Space.EMPTY,
Markers.EMPTY,
null,
null
);
}
if (typeName.startsWith("?") && typeName.contains("extends")) {
return new J.Wildcard(
randomId(),
Space.EMPTY,
Markers.EMPTY,
new JLeftPadded<>(Space.SINGLE_SPACE, J.Wildcard.Bound.Extends, Markers.EMPTY),
createTypeTree(typeName.substring(typeName.indexOf("extends") + "extends".length() + 1).trim()).withPrefix(Space.SINGLE_SPACE)
);
}
if (typeName.startsWith("?") && typeName.contains("super")) {
return new J.Wildcard(
randomId(),
Space.EMPTY,
Markers.EMPTY,
new JLeftPadded<>(Space.SINGLE_SPACE, J.Wildcard.Bound.Super, Markers.EMPTY),
createTypeTree(typeName.substring(typeName.indexOf("super") + "super".length() + 1).trim()).withPrefix(Space.SINGLE_SPACE)
);
}
if (typeName.indexOf('.') == -1) {
String javaLangType = TypeUtils.findQualifiedJavaLangTypeName(typeName);
JavaType type = javaLangType != null ?
JavaType.buildType(javaLangType) :
JavaType.ShallowClass.build(typeName);
return new J.Identifier(
randomId(),
Space.EMPTY,
Markers.EMPTY,
emptyList(),
typeName,
type,
null
);
}
return new J.Identifier(
randomId(),
Space.EMPTY,
Markers.EMPTY,
emptyList(),
typeName.substring(typeName.lastIndexOf('.') + 1),
JavaType.ShallowClass.build(typeName),
null
);
}

private List<String> splitTypeArguments(String args) {
List<String> result = new ArrayList<>();
int depth = 0;
int start = 0;
for (int i = 0; i < args.length(); i++) {
char c = args.charAt(i);
if (c == '<') {
depth++;
} else if (c == '>') {
depth--;
} else if (c == ',' && depth == 0) {
result.add(args.substring(start, i));
start = i + 1;
}
}
result.add(args.substring(start));
return result;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,133 @@ void foo() {
)
);
}

@Test
void replaceVariableAssignmentWithGenericReturnType() {
rewriteRun(
spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.util.Set<java.lang.String>"))
.parser(JavaParser.fromJavaVersion()
//language=java
.dependsOn(
"""
package bar;
import java.util.List;
public class Bar {
public static List<String> bar() {
return null;
}
}
"""
)
),
//language=java
java(
"""
import bar.Bar;
import java.util.List;
class Foo {
void foo() {
List<String> one = Bar.bar();
}
}
""",
"""
import bar.Bar;

import java.util.Set;

class Foo {
void foo() {
Set<String> one = Bar.bar();
}
}
"""
)
);
}

@Test
void replaceVariableAssignmentWithNestedGenericReturnType() {
rewriteRun(
spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.util.Map<java.lang.String, java.util.List<java.lang.Integer>>"))
.parser(JavaParser.fromJavaVersion()
//language=java
.dependsOn(
"""
package bar;
import java.util.List;
public class Bar {
public static List<String> bar() {
return null;
}
}
"""
)
),
//language=java
java(
"""
import bar.Bar;
import java.util.List;
class Foo {
void foo() {
List<String> one = Bar.bar();
}
}
""",
"""
import bar.Bar;
import java.util.List;
import java.util.Map;

class Foo {
void foo() {
Map<String, List<Integer>> one = Bar.bar();
}
}
"""
)
);
}

@Test
void replaceParameterizedReturnTypeWithRaw() {
rewriteRun(
spec -> spec.recipe(new ChangeMethodInvocationReturnType("bar.Bar bar()", "java.lang.Object"))
.parser(JavaParser.fromJavaVersion()
//language=java
.dependsOn(
"""
package bar;
import java.util.List;
public class Bar {
public static List<String> bar() {
return null;
}
}
"""
)
),
//language=java
java(
"""
import bar.Bar;
import java.util.List;
class Foo {
void foo() {
List<String> one = Bar.bar();
}
}
""",
"""
import bar.Bar;
class Foo {
void foo() {
Object one = Bar.bar();
}
}
"""
)
);
}
}