Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert rest of global typing pass group to mini passes #11717

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package org.enso.compiler.pass.lint;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.enso.compiler.context.InlineContext;
import org.enso.compiler.context.ModuleContext;
import org.enso.compiler.core.CompilerError;
import org.enso.compiler.core.IR;
import org.enso.compiler.core.ir.Expression;
import org.enso.compiler.core.ir.MetadataStorage;
import org.enso.compiler.core.ir.Name;
import org.enso.compiler.core.ir.Pattern;
import org.enso.compiler.core.ir.expression.Case;
import org.enso.compiler.core.ir.expression.Case.Branch;
import org.enso.compiler.core.ir.expression.warnings.Shadowed.PatternBinding;
import org.enso.compiler.pass.IRProcessingPass;
import org.enso.compiler.pass.MiniIRPass;
import org.enso.compiler.pass.MiniPassFactory;
import org.enso.compiler.pass.analyse.AliasAnalysis$;
import org.enso.compiler.pass.analyse.DataflowAnalysis$;
import org.enso.compiler.pass.analyse.DemandAnalysis$;
import org.enso.compiler.pass.analyse.TailCall;
import org.enso.compiler.pass.desugar.GenerateMethodBodies$;
import org.enso.compiler.pass.desugar.NestedPatternMatch$;
import org.enso.compiler.pass.resolve.IgnoredBindings$;
import scala.collection.immutable.List;
import scala.collection.immutable.Seq;
import scala.jdk.javaapi.CollectionConverters;

/**
* This pass detects and renames shadowed pattern fields.
*
* <p>This is necessary both in order to create a warning, but also to ensure that alias analysis
* doesn't get confused.
*
* <p>This pass requires no configuration.
*
* <p>This pass requires the context to provide:
*
* <p>- Nothing
*/
public final class ShadowedPatternFields implements MiniPassFactory {
public static final ShadowedPatternFields INSTANCE = new ShadowedPatternFields();

private ShadowedPatternFields() {}

@Override
public List<IRProcessingPass> precursorPasses() {
java.util.List<IRProcessingPass> list = java.util.List.of(GenerateMethodBodies$.MODULE$);
return CollectionConverters.asScala(list).toList();
}

@Override
public List<IRProcessingPass> invalidatedPasses() {
java.util.List<IRProcessingPass> list =
java.util.List.of(
AliasAnalysis$.MODULE$,
DataflowAnalysis$.MODULE$,
DemandAnalysis$.MODULE$,
IgnoredBindings$.MODULE$,
NestedPatternMatch$.MODULE$,
TailCall.INSTANCE);
return CollectionConverters.asScala(list).toList();
}

@Override
public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) {
return new Mini();
}

@Override
public MiniIRPass createForInlineCompilation(InlineContext inlineContext) {
return new Mini();
}

private static final class Mini extends MiniIRPass {
@Override
@SuppressWarnings("unchecked")
public Expression transformExpression(Expression expr) {
return switch (expr) {
case Case.Branch branch -> lintCaseBranch(branch);
case Case.Expr caseExpr -> {
Seq<Branch> newBranches = caseExpr.branches().map(this::lintCaseBranch).toSeq();
yield caseExpr.copy(
caseExpr.scrutinee(),
newBranches,
caseExpr.isNested(),
caseExpr.location(),
caseExpr.passData(),
caseExpr.diagnostics(),
caseExpr.id());
}
default -> expr;
};
}

/**
* Lints for shadowed pattern variables in a case branch.
*
* @param branch the case branch to lint
* @return `branch`, with warnings for any shadowed pattern variables
*/
private Case.Branch lintCaseBranch(Case.Branch branch) {
var newPattern = lintPattern(branch.pattern());
return branch.copy(
newPattern,
branch.expression(),
branch.terminalBranch(),
branch.location(),
branch.passData(),
branch.diagnostics(),
branch.id());
}

/**
* Lints a pattern for shadowed pattern variables.
*
* <p>A later pattern variable shadows an earlier pattern variable with the same name.
*
* @param pattern the pattern to lint
* @return `pattern`, with a warning applied to any shadowed pattern variables
*/
private Pattern lintPattern(Pattern pattern) {
var seenNames = new HashSet<String>();
var lastSeen = new HashMap<String, IR>();

return go(pattern, seenNames, lastSeen);
}

private Pattern go(Pattern pattern, Set<String> seenNames, Map<String, IR> lastSeen) {
return switch (pattern) {
case Pattern.Name named -> {
var name = named.name().name();
if (seenNames.contains(name)) {
var warning = new PatternBinding(name, lastSeen.get(name), named.identifiedLocation());
lastSeen.put(name, named);
var blank = new Name.Blank(named.identifiedLocation(), new MetadataStorage());
var patternCopy = named.copyWithName(blank);
patternCopy.getDiagnostics().add(warning);
yield patternCopy;
} else if (!(named.name() instanceof Name.Blank)) {
lastSeen.put(name, named);
seenNames.add(name);
yield named;
} else {
yield named;
}
}
case Pattern.Constructor cons -> {
var newFields =
cons.fields().reverse().map(field -> go(field, seenNames, lastSeen)).reverse();
yield cons.copyWithFields(newFields);
}
case Pattern.Literal literal -> literal;
case Pattern.Type typed -> {
var name = typed.name().name();
if (seenNames.contains(name)) {
var warning = new PatternBinding(name, lastSeen.get(name), typed.identifiedLocation());
lastSeen.put(name, typed);
var blank = new Name.Blank(typed.identifiedLocation(), new MetadataStorage());
var typedCopy =
typed.copy(
blank,
typed.tpe(),
typed.location(),
typed.passData(),
typed.diagnostics(),
typed.id());
typedCopy.getDiagnostics().add(warning);
yield typedCopy;
} else if (!(typed.name() instanceof Name.Blank)) {
lastSeen.put(name, typed);
seenNames.add(name);
yield typed;
} else {
yield typed;
}
}
case Pattern.Documentation doc -> throw new CompilerError(
"Branch documentation should be desugared at an earlier stage.");
default -> pattern;
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package org.enso.compiler.pass.optimise;

import java.util.ArrayList;
import java.util.stream.Stream;
import org.enso.compiler.context.InlineContext;
import org.enso.compiler.context.ModuleContext;
import org.enso.compiler.core.CompilerError;
import org.enso.compiler.core.ir.Expression;
import org.enso.compiler.core.ir.IdentifiedLocation;
import org.enso.compiler.core.ir.Pattern;
import org.enso.compiler.core.ir.expression.Case;
import org.enso.compiler.core.ir.expression.warnings.Unreachable;
import org.enso.compiler.pass.IRProcessingPass;
import org.enso.compiler.pass.MiniIRPass;
import org.enso.compiler.pass.MiniPassFactory;
import org.enso.compiler.pass.analyse.AliasAnalysis$;
import org.enso.compiler.pass.analyse.DataflowAnalysis$;
import org.enso.compiler.pass.analyse.DemandAnalysis$;
import org.enso.compiler.pass.analyse.TailCall;
import org.enso.compiler.pass.desugar.ComplexType$;
import org.enso.compiler.pass.desugar.FunctionBinding$;
import org.enso.compiler.pass.desugar.GenerateMethodBodies$;
import org.enso.compiler.pass.desugar.LambdaShorthandToLambda$;
import org.enso.compiler.pass.desugar.NestedPatternMatch$;
import org.enso.compiler.pass.resolve.DocumentationComments$;
import org.enso.compiler.pass.resolve.IgnoredBindings$;
import org.enso.scala.wrapper.ScalaConversions;
import scala.collection.immutable.List;
import scala.jdk.javaapi.CollectionConverters;

/**
* This pass discovers and optimizes away unreachable case branches.
*
* <p>It removes these unreachable expressions from the IR, and attaches a {@link
* org.enso.compiler.core.ir.Warning} diagnostic to the case expression itself.
*
* <p>Currently, a branch is considered 'unreachable' by this pass if:
*
* <ul>
* <li>It occurs after a catch-all branch.
* </ul>
*
* <p>In the future, this pass should be expanded to consider patterns that are entirely subsumed by
* previous patterns in its definition of unreachable, but this requires doing sophisticated
* coverage analysis, and hence should happen as part of the broader refactor of nested patterns
* desugaring.
*
* <p>This pass requires no configuration.
*
* <p>This pass requires the context to provide:
*
* <ul>
* <li>Nothing
* </ul>
*/
public final class UnreachableMatchBranches implements MiniPassFactory {
private UnreachableMatchBranches() {}

public static final UnreachableMatchBranches INSTANCE = new UnreachableMatchBranches();

@Override
public List<IRProcessingPass> precursorPasses() {
java.util.List<IRProcessingPass> passes = new ArrayList<>();
passes.add(ComplexType$.MODULE$);
passes.add(DocumentationComments$.MODULE$);
passes.add(FunctionBinding$.MODULE$);
passes.add(GenerateMethodBodies$.MODULE$);
passes.add(LambdaShorthandToLambda$.MODULE$);
return CollectionConverters.asScala(passes).toList();
}

@Override
public List<IRProcessingPass> invalidatedPasses() {
java.util.List<IRProcessingPass> passes = new ArrayList<>();
passes.add(AliasAnalysis$.MODULE$);
passes.add(DataflowAnalysis$.MODULE$);
passes.add(DemandAnalysis$.MODULE$);
passes.add(IgnoredBindings$.MODULE$);
passes.add(NestedPatternMatch$.MODULE$);
passes.add(TailCall.INSTANCE);
return CollectionConverters.asScala(passes).toList();
}

@Override
public MiniIRPass createForInlineCompilation(InlineContext inlineContext) {
return new Mini();
}

@Override
public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) {
return new Mini();
}

private static class Mini extends MiniIRPass {
@Override
public Expression transformExpression(Expression expr) {
return switch (expr) {
case Case cse -> optimizeCase(cse);
default -> expr;
};
}

/**
* Optimizes a case expression by removing unreachable branches.
*
* <p>Additionally, it will attach a warning about unreachable branches to the case expression.
*
* @param cse the case expression to optimize
* @return `cse` with unreachable branches removed
*/
private Case optimizeCase(Case cse) {
if (cse instanceof Case.Expr expr) {
var branches = CollectionConverters.asJava(expr.branches());
var reachableNonCatchAllBranches =
branches.stream().takeWhile(branch -> !isCatchAll(branch));
var firstCatchAll = branches.stream().filter(this::isCatchAll).findFirst();
var unreachableBranches =
branches.stream().dropWhile(branch -> !isCatchAll(branch)).skip(1).toList();
List<Case.Branch> reachableBranches;
if (firstCatchAll.isPresent()) {
reachableBranches = appended(reachableNonCatchAllBranches, firstCatchAll.get());
} else {
reachableBranches = ScalaConversions.nil();
}

if (unreachableBranches.isEmpty()) {
return expr;
} else {
var firstUnreachableWithLoc =
unreachableBranches.stream()
.filter(branch -> branch.identifiedLocation() != null)
.findFirst();
var lastUnreachableWithLoc =
unreachableBranches.stream()
.filter(branch -> branch.identifiedLocation() != null)
.reduce((first, second) -> second);
IdentifiedLocation unreachableLocation = null;
if (firstUnreachableWithLoc.isPresent() && lastUnreachableWithLoc.isPresent()) {
unreachableLocation =
new IdentifiedLocation(
firstUnreachableWithLoc.get().location().get().start(),
lastUnreachableWithLoc.get().location().get().end(),
firstUnreachableWithLoc.get().id());
}

var diagnostic = new Unreachable.Branches(unreachableLocation);
var copiedExpr = expr.copyWithBranches(reachableBranches);
copiedExpr.getDiagnostics().add(diagnostic);
return copiedExpr;
}
} else {
throw new CompilerError("Unexpected case branch.");
}
}

/**
* Determines if a branch is a catch all branch.
*
* @param branch the branch to check
* @return `true` if `branch` is catch-all, otherwise `false`
*/
private boolean isCatchAll(Case.Branch branch) {
return switch (branch.pattern()) {
case Pattern.Name ignored -> true;
default -> false;
};
}

private static List<Case.Branch> appended(Stream<Case.Branch> branches, Case.Branch branch) {
var ret = new ArrayList<>(branches.toList());
ret.add(branch);
return CollectionConverters.asScala(ret).toList();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class Passes(config: CompilerConfig) {
)
} else List())
++ List(
ShadowedPatternFields,
UnreachableMatchBranches,
ShadowedPatternFields.INSTANCE,
UnreachableMatchBranches.INSTANCE,
NestedPatternMatch,
IgnoredBindings,
TypeFunctions,
Expand Down
Loading
Loading