diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java index 68ba9620f70..f750d286937 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java @@ -186,6 +186,7 @@ class A { @Test void forceImportNoJavaRecord() { + // Add import for a class named `Record`, even within the same package, to avoid conflicts with java.lang.Record rewriteRun( spec -> spec.recipe(toRecipe(() -> new AddImport<>("com.acme.bank.Record", null, false))), //language=java @@ -211,6 +212,7 @@ class Foo { @Test void notForceImportJavaRecord() { + // Do not add import for java.lang.Record by default rewriteRun( spec -> spec.recipe(toRecipe(() -> new AddImport<>("java.lang.Record", null, false))), //language=java diff --git a/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java b/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java index c9908e66aaa..4299ef1dd08 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java @@ -105,10 +105,14 @@ public AddImport(@Nullable String packageName, String typeName, @Nullable String return cu; } - // No need to add imports if the class to import is in java.lang, or if the classes are within the same package - if (("java.lang".equals(packageName) && StringUtils.isBlank(member)) || - (!"Record".equals(typeName) && cu.getPackageDeclaration() != null && - packageName.equals(cu.getPackageDeclaration().getExpression().printTrimmed(getCursor())))) { + // No need to add imports if the class to import is in java.lang + if ("java.lang".equals(packageName) && StringUtils.isBlank(member)) { + return cu; + } + // Nor if the classes are within the same package + if (!"Record".equals(typeName) && // Record's late addition to `java.lang` might conflict with user class + cu.getPackageDeclaration() != null && + packageName.equals(cu.getPackageDeclaration().getExpression().printTrimmed(getCursor()))) { return cu; }