Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent importing classes which are already imported via * #4320

Merged
merged 5 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/main/java/spoon/support/compiler/jdt/JDTImportBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import spoon.reflect.factory.Factory;
import spoon.reflect.reference.CtReference;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
Expand Down Expand Up @@ -59,9 +61,8 @@ void build() {
int lastDot = importName.lastIndexOf('.');
String packageName = importName.substring(0, lastDot);

// only get package from the model by traversing from rootPackage the model
// it does not use reflection to achieve that
CtPackage ctPackage = this.factory.Package().get(packageName);
// load package by looking up in the class loader or in the model being built
CtPackage ctPackage = loadPackage(packageName);

if (ctPackage != null) {
this.imports.add(createImportWithPosition(ctPackage.getReference(), importRef));
Expand Down Expand Up @@ -138,6 +139,23 @@ private CtImport createUnresolvedImportWithPosition(String ref, boolean isStatic
return imprt;
}

private CtPackage loadPackage(String packageName) {
// get all packages known for the current class loader and the ones which are accessible from it
Package[] allPackagesInAllClassLoaders = Package.getPackages();

Optional<Package> requiredPackage = Arrays.stream(allPackagesInAllClassLoaders)
.filter(pkg -> pkg.getName().equals(packageName))
.findAny();
if (requiredPackage.isPresent()) {
CtPackage ctPackage = factory.createPackage();
ctPackage.setSimpleName(requiredPackage.get().getName());
return ctPackage;
}

// get package by traversing the model
return factory.Package().get(packageName);
}

private CtType getOrLoadClass(String className) {
CtType klass = this.factory.Type().get(className);

Expand Down
12 changes: 9 additions & 3 deletions src/test/java/spoon/reflect/visitor/ImportCleanerTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package spoon.reflect.visitor;

import org.junit.Test;
import org.junit.jupiter.api.Test;
import spoon.Launcher;
import spoon.reflect.CtModel;
import spoon.reflect.declaration.CtCompilationUnit;
Expand All @@ -16,13 +16,19 @@
public class ImportCleanerTest {

@Test
public void testDoesNotDuplicateUnresolvedImports() {
void testDoesNotImportClassesIfAlreadyImportedViaWildCard() {
// contract: The import cleaner should not import classes if they are encompassed in wildcard import.
testImportCleanerDoesNotAlterImports("src/test/resources/importCleaner/WildCardImport.java", "WildCardImport");
}

@Test
void testDoesNotDuplicateUnresolvedImports() {
// contract: The import cleaner should not duplicate unresolved imports
testImportCleanerDoesNotAlterImports("./src/test/resources/unresolved/UnresolvedImport.java", "UnresolvedImport");
}

@Test
public void testDoesNotImportInheritedStaticMethod() {
void testDoesNotImportInheritedStaticMethod() {
// contract: The import cleaner should not import static attributes that are inherited
testImportCleanerDoesNotAlterImports("./src/test/resources/inherit-static-method", "Derived");
}
Expand Down
5 changes: 5 additions & 0 deletions src/test/resources/importCleaner/WildCardImport.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import java.util.*;

public class WildCardImport {
private static List<Integer> x = new ArrayList<>();
}