Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 31 additions & 219 deletions nullability/inference/collect_evidence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ using ::clang::dataflow::Formula;
using ::clang::dataflow::RecordInitListHelper;
using ::clang::dataflow::WatchedLiteralsSolver;

using ConcreteNullabilityCache =
absl::flat_hash_map<const Decl *,
std::optional<const PointerTypeNullability>>;

namespace {
/// Shared base class for visitors that walk the AST for evidence collection
/// purposes, to ensure they see the same nodes.
Expand Down Expand Up @@ -2672,48 +2668,6 @@ static void wrappedEmit(llvm::function_ref<EvidenceEmitter> Emit,
serializeLoc(SM, Loc).Loc));
}

/// Returns a function that the analysis can use to override Decl nullability
/// values from the source code being analyzed with previously inferred
/// nullabilities.
///
/// In practice, this should only override the default nullability for Decls
/// that do not spell out a nullability in source code, because we only pass in
/// inferences from the previous round which are non-trivial and annotations
/// "inferred" by reading an annotation from source code in the previous round
/// were marked trivial.
static auto getConcreteNullabilityOverrideFromPreviousInferences(
ConcreteNullabilityCache &Cache, USRCache &USRCache,
const PreviousInferences &PreviousInferences) {
return [&](const Decl &D) -> std::optional<const PointerTypeNullability *> {
auto [It, Inserted] = Cache.try_emplace(&D);
if (Inserted) {
std::optional<const Decl *> FingerprintedDecl;
Slot Slot;
if (auto *FD = dyn_cast<FunctionDecl>(&D)) {
FingerprintedDecl = FD;
Slot = SLOT_RETURN_TYPE;
} else if (auto *PD = dyn_cast<ParmVarDecl>(&D)) {
if (auto *Parent = dyn_cast_or_null<FunctionDecl>(
PD->getParentFunctionOrMethod())) {
FingerprintedDecl = Parent;
Slot = paramSlot(PD->getFunctionScopeIndex());
}
}
if (!FingerprintedDecl) return std::nullopt;
auto Fingerprint =
fingerprint(getOrGenerateUSR(USRCache, **FingerprintedDecl), Slot);
if (PreviousInferences.Nullable->contains(Fingerprint)) {
It->second.emplace(NullabilityKind::Nullable);
} else if (PreviousInferences.Nonnull->contains(Fingerprint)) {
It->second.emplace(NullabilityKind::NonNull);
} else {
It->second = std::nullopt;
}
}
if (!It->second) return std::nullopt;
return &*It->second;
};
}

template <typename ContainerT>
static bool hasAnyInferenceTargets(const ContainerT &Decls) {
Expand All @@ -2738,15 +2692,16 @@ using EBInitHandler =
llvm::function_ref<void(std::string_view USR, PointerNullState NullState,
const SerializedSrcLoc &Loc)>;

// If D is a constructor definition, summarizes cases of potential
// If D is a constructor definition, summarizes the exit block.
// From the summary, we can later potentially produce
// LEFT_NULLABLE_BY_CONSTRUCTOR evidence for smart pointer fields implicitly
// default-initialized and left nullable in the exit block of the constructor
// body.
static void processConstructorExitBlock(const clang::Decl &MaybeConstructor,
const Environment &ExitEnv,
USRCache &USRCache,
EBInitHandler InitHandler) {
auto *Ctor = dyn_cast<CXXConstructorDecl>(&MaybeConstructor);
// body. `EBSummary` is an out-parameter for accumulating results.
static void summarizeConstructorExitBlock(const clang::Decl& MaybeConstructor,
const Environment& ExitEnv,
USRCache& USRCache,
ExitBlockSummary& EBSummary) {
auto* Ctor = dyn_cast<CXXConstructorDecl>(&MaybeConstructor);
if (!Ctor) return;
for (auto *Initializer : Ctor->inits()) {
if (Initializer->isWritten() || Initializer->isInClassMemberInitializer()) {
Expand Down Expand Up @@ -2792,24 +2747,30 @@ static void processConstructorExitBlock(const clang::Decl &MaybeConstructor,
PointerNullState NullState = getPointerNullState(*PV);
auto &SM =
Field->getDeclContext()->getParentASTContext().getSourceManager();
std::string_view USR = getOrGenerateUSR(USRCache, *Field);

InitOnExitSummary& InitSummary = *EBSummary.add_ctor_inits_on_exit();
InitSummary.mutable_field()->set_usr(getOrGenerateUSR(USRCache, *Field));
*InitSummary.mutable_null_state() = savePointerNullState(NullState);
SerializedSrcLoc Loc = serializeLoc(
SM, Ctor->isImplicit() ? Field->getBeginLoc() : Ctor->getBeginLoc());
InitHandler(USR, NullState, Loc);
InitSummary.set_location(Loc.Loc);
}
}

// Supported late initializers are no-argument SetUp methods of classes that
// inherit from ::testing::Test. From the exit block of such a method, we
// collect LEFT_NOT_NULLABLE_BY_LATE_INITIALIZER evidence for smart pointer
// fields that are not nullable. This allows ignoring the
// LEFT_NULLABLE_BY_CONSTRUCTOR evidence for such a field.
static void processSupportedLateInitializerExitBlock(
const clang::Decl &MaybeLateInitializationMethod,
const Environment &ExitEnv, USRCache &USRCache, EBInitHandler InitHandler) {
auto *Method = dyn_cast<CXXMethodDecl>(&MaybeLateInitializationMethod);
// Summarize the exit block of a supported late initializer. From the
// summary, we can later potentially produce
// LEFT_NOT_NULLABLE_BY_LATE_INITIALIZER evidence, for smart pointer fields that
// are not nullable. This allows ignoring the LEFT_NULLABLE_BY_CONSTRUCTOR
// evidence for such a field. Supported late initializers are no-argument SetUp
// methods of classes that inherit from ::testing::Test. `EBSummary` is an
// out-parameter for accumulating results.
static void summarizeSupportedLateInitializerExitBlock(
const clang::Decl& MaybeLateInitializationMethod,
const Environment& ExitEnv, USRCache& USRCache,
ExitBlockSummary& EBSummary) {
auto* Method = dyn_cast<CXXMethodDecl>(&MaybeLateInitializationMethod);
if (!Method || !Method->isVirtual() || Method->getNumParams() != 0) return;
if (IdentifierInfo *Identifier = Method->getIdentifier();
if (IdentifierInfo* Identifier = Method->getIdentifier();
!Identifier || Identifier->getName() != "SetUp") {
return;
}
Expand Down Expand Up @@ -2838,71 +2799,18 @@ static void processSupportedLateInitializerExitBlock(
cast<dataflow::RecordStorageLocation>(ChildLoc), ExitEnv);
if (PV != nullptr && hasPointerNullState(*PV)) {
PointerNullState NullState = getPointerNullState(*PV);
std::string_view USR = getOrGenerateUSR(USRCache, *ChildDecl);
InitOnExitSummary& LateInitSummary = *EBSummary.add_late_inits_on_exit();
LateInitSummary.mutable_field()->set_usr(
getOrGenerateUSR(USRCache, *ChildDecl));
*LateInitSummary.mutable_null_state() = savePointerNullState(NullState);
SerializedSrcLoc Loc =
serializeLoc(Method->getParentASTContext().getSourceManager(),
Method->getBeginLoc());
InitHandler(USR, NullState, Loc);
LateInitSummary.set_location(Loc.Loc);
}
}
}

static void collectEvidenceFromConstructorExitBlock(
const clang::Decl &MaybeConstructor, const Environment &ExitEnv,
USRCache &USRCache, EvidenceCollector &Collector) {
processConstructorExitBlock(
MaybeConstructor, ExitEnv, USRCache,
[&Collector](std::string_view USR, PointerNullState NullState,
const SerializedSrcLoc &Loc) {
Collector.collectConstructorExitBlock(USR, NullState, Loc);
});
}

static void collectEvidenceFromSupportedLateInitializerExitBlock(
const clang::Decl &MaybeLateInitializationMethod,
const Environment &ExitEnv, USRCache &USRCache,
EvidenceCollector &Collector) {
processSupportedLateInitializerExitBlock(
MaybeLateInitializationMethod, ExitEnv, USRCache,
[&Collector](std::string_view USR, PointerNullState NullState,
const SerializedSrcLoc &Loc) {
Collector.collectSupportedLateInitializerExitBlock(USR, NullState, Loc);
});
}

// `EBSummary` is an out-parameter for accumulating results.
static void summarizeConstructorExitBlock(const clang::Decl &MaybeConstructor,
const Environment &ExitEnv,
USRCache &USRCache,
ExitBlockSummary &EBSummary) {
processConstructorExitBlock(
MaybeConstructor, ExitEnv, USRCache,
[&EBSummary](std::string_view USR, PointerNullState NullState,
const SerializedSrcLoc &Loc) {
InitOnExitSummary &InitSummary = *EBSummary.add_ctor_inits_on_exit();
InitSummary.mutable_field()->set_usr(USR);
*InitSummary.mutable_null_state() = savePointerNullState(NullState);
InitSummary.set_location(Loc.Loc);
});
}

// `EBSummary` is an out-parameter for accumulating results.
static void summarizeSupportedLateInitializerExitBlock(
const clang::Decl &MaybeLateInitializationMethod,
const Environment &ExitEnv, USRCache &USRCache,
ExitBlockSummary &EBSummary) {
processSupportedLateInitializerExitBlock(
MaybeLateInitializationMethod, ExitEnv, USRCache,
[&EBSummary](std::string_view USR, PointerNullState NullState,
const SerializedSrcLoc &Loc) {
InitOnExitSummary &LateInitSummary =
*EBSummary.add_late_inits_on_exit();
LateInitSummary.mutable_field()->set_usr(USR);
*LateInitSummary.mutable_null_state() = savePointerNullState(NullState);
LateInitSummary.set_location(Loc.Loc);
});
}

// Checks the "last layer" forwarding functions called from the given statement.
// This allows us to collect references made within forwarding functions, as if
// they were made directly by the statement. (skipping through the forwarding).
Expand Down Expand Up @@ -3129,102 +3037,6 @@ static std::vector<InferableSlot> gatherInferableSlots(
return InferableSlots;
}

llvm::Error collectEvidenceFromDefinition(
const Decl &Definition, llvm::function_ref<EvidenceEmitter> Emit,
USRCache &USRCache, const NullabilityPragmas &Pragmas,
const PreviousInferences &PreviousInferences,
const SolverFactory &MakeSolver) {
std::optional<DeclStmt> DeclStmtForVarDecl;
auto T = getTarget(Definition, DeclStmtForVarDecl);
if (!T) return T.takeError();
Stmt &TargetStmt = **T;

const auto *absl_nullable TargetFunc = dyn_cast<FunctionDecl>(&Definition);
dataflow::ReferencedDecls ReferencedDecls =
TargetFunc != nullptr ? dataflow::getReferencedDecls(*TargetFunc)
: dataflow::getReferencedDecls(TargetStmt);
collectReferencesFromForwardingFunctions(TargetStmt, ReferencedDecls);

// TODO: b/416755108 -- We should be able to check functions as
// well (and therefore drop the `!TargetFunc` filter), but we're missing some
// Referenced constructors, so `hasAnyInferenceTargets` will fail for certain
// functions.
if (!TargetFunc && !isInferenceTarget(Definition) &&
!hasAnyInferenceTargets(ReferencedDecls))
return llvm::Error::success();

ASTContext &Ctx = Definition.getASTContext();
llvm::Expected<dataflow::AdornedCFG> ACFG =
dataflow::AdornedCFG::build(Definition, TargetStmt, Ctx);
if (!ACFG) return ACFG.takeError();

std::unique_ptr<dataflow::Solver> Solver = MakeSolver();
DataflowAnalysisContext AnalysisContext(*Solver);
Environment Env = TargetFunc ? Environment(AnalysisContext, *TargetFunc)
: Environment(AnalysisContext, TargetStmt);
PointerNullabilityAnalysis Analysis(Ctx, Env, Pragmas);

std::unique_ptr<LocFilter> NotTestMainFileLocFilter =
getNotTestMainFileLocFilter(Ctx);

std::vector<InferableSlot> InferableSlots = gatherInferableSlots(
TypeNullabilityDefaults(Ctx, Pragmas), TargetFunc, TargetStmt,
ReferencedDecls, Analysis, AnalysisContext.arena(), USRCache,
NotTestMainFileLocFilter.get());

// Here, we overlay new knowledge from past iterations over the symbolic
// entities for the `InferableSlots` (whose symbols are invariant across
// inference iterations).
const auto &InferableSlotsConstraint = getConstraintsOnInferableSlots(
InferableSlots, PreviousInferences, AnalysisContext.arena());

ConcreteNullabilityCache ConcreteNullabilityCache;
Analysis.assignNullabilityOverride(
getConcreteNullabilityOverrideFromPreviousInferences(
ConcreteNullabilityCache, USRCache, PreviousInferences));

std::vector<
std::optional<dataflow::DataflowAnalysisState<PointerNullabilityLattice>>>
Results;
dataflow::CFGEltCallbacks<PointerNullabilityAnalysis> PostAnalysisCallbacks;
PostAnalysisCallbacks.Before =
[&](const CFGElement& Element,
const dataflow::DataflowAnalysisState<PointerNullabilityLattice>&
State) {
if (Solver->reachedLimit()) return;

EvidenceCollector Collector(InferableSlots, InferableSlotsConstraint,
State.Env, *Solver, Emit);
NullabilityBehaviorVisitor<EvidenceCollector>::visit(
InferableSlots, USRCache, State.Lattice, State.Env,
Ctx.getSourceManager(), Element, NotTestMainFileLocFilter.get(),
Collector);
};
if (llvm::Error Error = dataflow::runDataflowAnalysis(*ACFG, Analysis, Env,
PostAnalysisCallbacks)
.moveInto(Results))
return Error;

if (Results.empty()) return llvm::Error::success();

if (std::optional<dataflow::DataflowAnalysisState<PointerNullabilityLattice>>
&ExitBlockResult = Results[ACFG->getCFG().getExit().getBlockID()]) {
EvidenceCollector Collector(InferableSlots, InferableSlotsConstraint,
ExitBlockResult->Env, *Solver, Emit);
collectEvidenceFromConstructorExitBlock(Definition, ExitBlockResult->Env,
USRCache, Collector);
collectEvidenceFromSupportedLateInitializerExitBlock(
Definition, ExitBlockResult->Env, USRCache, Collector);
}

if (Solver->reachedLimit()) {
return llvm::createStringError(llvm::errc::interrupted,
"SAT solver reached iteration limit");
}

return llvm::Error::success();
}

static void summarizeFromTUIndexIfPresent(
const FunctionDecl* absl_nullable Func, const VirtualMethodIndex& TUIndex,
USRCache& Cache, VirtualMethodIndexSummary& Summary) {
Expand Down
32 changes: 12 additions & 20 deletions nullability/inference/collect_evidence.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ struct PreviousInferences {
std::make_shared<const SortedFingerprintVector>();
};

/// Creates a solver with default parameters that is suitable for passing to
/// `collectEvidenceFromDefinition()`.
/// Creates a solver with default parameters that is suitable for inference.
std::unique_ptr<dataflow::Solver> makeDefaultSolverForInference();

/// Callback used to report collected nullability evidence.
Expand All @@ -178,30 +177,23 @@ llvm::unique_function<EvidenceEmitter> evidenceEmitterWithPropagation(
llvm::unique_function<EvidenceEmitter> Emit,
absl_nonnull std::shared_ptr<const VirtualMethodIndex> Index);

/// Analyze code (such as a function body or variable initializer) to infer
/// nullability.
///
/// Produces Evidence constraining the nullability slots of the symbols that
/// the code interacts with, such as the function's own parameters.
/// This is based on the code's behavior and our definition of null-safety.
///
/// Summarizes Nullability-relevant behaviors in and context for `Definition`
/// (which can be a function body or variable initializer). The summary can then
/// be used to collect evidence and infer nullability.
/// If std::nullopt is returned, the analysis succeeded, but there's no relevant
/// content.
/// It is up to the caller to ensure the definition is eligible for inference
/// (function has a body, is not dependent, etc).
llvm::Error collectEvidenceFromDefinition(
const Decl &, llvm::function_ref<EvidenceEmitter>, USRCache &USRCache,
const NullabilityPragmas &Pragmas,
const PreviousInferences &PreviousInferences = {},
const SolverFactory &MakeSolver = makeDefaultSolverForInference);

// Summarizes Nullability-relevant behaviors in and context for `Definition`.
// If std::nullopt is returned, the analysis succeeded, but there's no relevant
// content.
llvm::Expected<std::optional<CFGSummary>> summarizeDefinition(
const Decl& Definition, USRCache& USRCache,
const NullabilityPragmas& Pragmas,
const VirtualMethodIndex& VirtualMethodsInTU,
const SolverFactory& MakeSolver = makeDefaultSolverForInference);

/// Produces Evidence constraining the nullability slots of the symbols that
/// the code interacts with, such as the function's own parameters.
/// This is based on the code's behavior (which is summarized in the CFGSummary)
/// and our definition of null-safety.
llvm::Error collectEvidenceFromSummary(
const CFGSummary& Summary, llvm::function_ref<EvidenceEmitter> Emit,
const PreviousInferences& PreviousInferences,
Expand All @@ -210,7 +202,7 @@ llvm::Error collectEvidenceFromSummary(
/// Gathers evidence of a symbol's nullability from a declaration of it.
///
/// These are trivial "inferences" of what's already written in the code. e.g:
/// void foo(Nullable<int*>);
/// void foo(int* _Nullable);
/// The first parameter of foo must be nullable.
///
/// It is the caller's responsibility to ensure that the symbol is inferable.
Expand All @@ -227,7 +219,7 @@ struct EvidenceSites {
/// Definitions (e.g. function body, variable initializer) that can be
/// analyzed.
/// This will always be concrete code, not a template pattern. These may be
/// passed to collectEvidenceFromDefinition().
/// passed to summarizeDefinition().
llvm::DenseSet<const Decl *absl_nonnull> Definitions;

/// Find the evidence sites within the provided AST. If
Expand Down
Loading