Skip to content

Commit

Permalink
Facet types support rewrite (where .A =...) constraints (#4613)
Browse files Browse the repository at this point in the history
* Rewrite constraints are stored in a facet type, substituted, imported,
and formatted.
* We now distinguish `.Self` from other symbolic bindings in two ways:
* `.Self` itself now has an invalid compile time binding index (since it
doesn't bind to any of the generic parameters). As a result, we no
longer need to create a generic region in `handle_where.cpp`.
* There is a new phase tracking values that are only symbolic because
they transitively depend on `.Self`. This allows us to give the result
of a `where` expression template phase as long as it doesn't use any
symbolic constants other than `.Self` or other designators.
* `AddConstant` has been removed from `check/context` since it was only
used from `eval`. This meant less plumbing of the phase change.
* Evaluation of `BindSymbolicName` now also performs substitution into
its type.
* Include a bit more information in some diagnostics.
* `StringifyTypeExpr` outputs rewrites, which required adding support
for associated entities as well.
  * Associated entities now have an entity name set when importing.
* Adds tests for some interesting cases with rewrites and uses of
`.Self` mixed with other symbolic constants.

Still to do:
* There is no validation that any particular type satisfies rewrite
constraints.
  * Access to members of a facet type do not see the rewritten values.
* Impls don't recognize whether associated constants have rewrites
setting their values.
  * No support for resolving facet types.

---------

Co-authored-by: Josh L <[email protected]>
Co-authored-by: Richard Smith <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent dc5edb8 commit 33110d0
Show file tree
Hide file tree
Showing 44 changed files with 3,113 additions and 1,446 deletions.
13 changes: 7 additions & 6 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ auto Context::FinishInst(SemIR::InstId inst_id, SemIR::Inst inst) -> void {

// If the instruction has a symbolic constant type, track that we need to
// substitute into it.
if (types().GetConstantId(inst.type_id()).is_symbolic()) {
if (constant_values().DependsOnGenericParameter(
types().GetConstantId(inst.type_id()))) {
dep_kind |= GenericRegionStack::DependencyKind::SymbolicType;
}

Expand All @@ -128,7 +129,7 @@ auto Context::FinishInst(SemIR::InstId inst_id, SemIR::Inst inst) -> void {

// If the constant value is symbolic, track that we need to substitute into
// it.
if (const_id.is_symbolic()) {
if (constant_values().DependsOnGenericParameter(const_id)) {
dep_kind |= GenericRegionStack::DependencyKind::SymbolicConstant;
}
}
Expand Down Expand Up @@ -1281,7 +1282,7 @@ auto Context::TryToDefineType(SemIR::TypeId type_id,
ResolveSpecificDefinition(*this, interface.specific_id);
}
}
// TODO: Process other requirements.
// TODO: Finish facet type resolution.
}

return true;
Expand All @@ -1304,9 +1305,9 @@ auto Context::GetTypeIdForTypeConstant(SemIR::ConstantId constant_id)
auto Context::FacetTypeFromInterface(SemIR::InterfaceId interface_id,
SemIR::SpecificId specific_id)
-> SemIR::FacetType {
SemIR::FacetTypeId facet_type_id = facet_types().Add(SemIR::FacetTypeInfo{
.impls_constraints = {{interface_id, specific_id}},
.requirement_block_id = SemIR::InstBlockId::Invalid});
SemIR::FacetTypeId facet_type_id = facet_types().Add(
SemIR::FacetTypeInfo{.impls_constraints = {{interface_id, specific_id}},
.other_requirements = false});
return {.type_id = SemIR::TypeId::TypeType, .facet_type_id = facet_type_id};
}

Expand Down
7 changes: 0 additions & 7 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,6 @@ class Context {
return AddPatternInst(SemIR::LocIdAndInst(node_id, inst));
}

// Adds an instruction to the constants block, returning the produced ID.
auto AddConstant(SemIR::Inst inst, bool is_symbolic) -> SemIR::ConstantId {
auto const_id = constants().GetOrAdd(inst, is_symbolic);
CARBON_VLOG("AddConstant: {0}\n", inst);
return const_id;
}

// Pushes a parse tree node onto the stack, storing the SemIR::Inst as the
// result.
template <typename InstT>
Expand Down
112 changes: 88 additions & 24 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ namespace {
enum class Phase : uint8_t {
// Value could be entirely and concretely computed.
Template,
// Evaluation phase is symbolic because the expression involves specifically a
// reference to `.Self`.
PeriodSelfSymbolic,
// Evaluation phase is symbolic because the expression involves a reference to
// a symbolic binding.
Symbolic,
Expand All @@ -191,16 +194,20 @@ enum class Phase : uint8_t {
} // namespace

// Gets the phase in which the value of a constant will become available.
static auto GetPhase(SemIR::ConstantId constant_id) -> Phase {
static auto GetPhase(EvalContext& eval_context, SemIR::ConstantId constant_id)
-> Phase {
if (!constant_id.is_constant()) {
return Phase::Runtime;
} else if (constant_id == SemIR::ConstantId::Error) {
return Phase::UnknownDueToError;
} else if (constant_id.is_template()) {
return Phase::Template;
} else if (eval_context.constant_values().DependsOnGenericParameter(
constant_id)) {
return Phase::Symbolic;
} else {
CARBON_CHECK(constant_id.is_symbolic());
return Phase::Symbolic;
return Phase::PeriodSelfSymbolic;
}
}

Expand All @@ -210,14 +217,34 @@ static auto LatestPhase(Phase a, Phase b) -> Phase {
std::max(static_cast<uint8_t>(a), static_cast<uint8_t>(b)));
}

// `where` expressions using `.Self` should not be considered symbolic
// - `Interface where .Self impls I and .A = bool` -> template
// - `T:! type` ... `Interface where .A = T` -> symbolic, since uses `T` which
// is symbolic and not due to `.Self`.
static auto UpdatePhaseIgnorePeriodSelf(EvalContext& eval_context,
SemIR::ConstantId constant_id,
Phase* phase) {
Phase constant_phase = GetPhase(eval_context, constant_id);
// Since LatestPhase(x, Phase::Template) == x, this is equivalent to replacing
// Phase::PeriodSelfSymbolic with Phase::Template.
if (constant_phase != Phase::PeriodSelfSymbolic) {
*phase = LatestPhase(*phase, constant_phase);
}
}

// Forms a `constant_id` describing a given evaluation result.
static auto MakeConstantResult(Context& context, SemIR::Inst inst, Phase phase)
-> SemIR::ConstantId {
switch (phase) {
case Phase::Template:
return context.AddConstant(inst, /*is_symbolic=*/false);
return context.constants().GetOrAdd(inst,
SemIR::ConstantStore::IsTemplate);
case Phase::PeriodSelfSymbolic:
return context.constants().GetOrAdd(
inst, SemIR::ConstantStore::IsPeriodSelfSymbolic);
case Phase::Symbolic:
return context.AddConstant(inst, /*is_symbolic=*/true);
return context.constants().GetOrAdd(inst,
SemIR::ConstantStore::IsSymbolic);
case Phase::UnknownDueToError:
return SemIR::ConstantId::Error;
case Phase::Runtime:
Expand Down Expand Up @@ -270,7 +297,7 @@ static auto MakeFloatResult(Context& context, SemIR::TypeId type_id,
static auto GetConstantValue(EvalContext& eval_context, SemIR::InstId inst_id,
Phase* phase) -> SemIR::InstId {
auto const_id = eval_context.GetConstantValue(inst_id);
*phase = LatestPhase(*phase, GetPhase(const_id));
*phase = LatestPhase(*phase, GetPhase(eval_context, const_id));
return eval_context.constant_values().GetInstId(const_id);
}

Expand All @@ -279,7 +306,7 @@ static auto GetConstantValue(EvalContext& eval_context, SemIR::InstId inst_id,
static auto GetConstantValue(EvalContext& eval_context, SemIR::TypeId type_id,
Phase* phase) -> SemIR::TypeId {
auto const_id = eval_context.GetConstantValue(type_id);
*phase = LatestPhase(*phase, GetPhase(const_id));
*phase = LatestPhase(*phase, GetPhase(eval_context, const_id));
return eval_context.context().GetTypeIdForTypeConstant(const_id);
}

Expand Down Expand Up @@ -395,7 +422,7 @@ static auto GetConstantValue(EvalContext& eval_context,
}

// Like `GetConstantValue` but does a `FacetTypeId` -> `FacetTypeInfo`
// conversion.
// conversion. Does not perform canonicalization.
static auto GetConstantFacetTypeInfo(EvalContext& eval_context,
SemIR::FacetTypeId facet_type_id,
Phase* phase) -> SemIR::FacetTypeInfo {
Expand All @@ -404,8 +431,14 @@ static auto GetConstantFacetTypeInfo(EvalContext& eval_context,
interface.specific_id =
GetConstantValue(eval_context, interface.specific_id, phase);
}
std::sort(info.impls_constraints.begin(), info.impls_constraints.end());
// TODO: Process & canonicalize other requirements.
for (auto& rewrite : info.rewrite_constraints) {
rewrite.lhs_const_id = eval_context.GetInContext(rewrite.lhs_const_id);
rewrite.rhs_const_id = eval_context.GetInContext(rewrite.rhs_const_id);
// `where` requirements using `.Self` should not be considered symbolic
UpdatePhaseIgnorePeriodSelf(eval_context, rewrite.lhs_const_id, phase);
UpdatePhaseIgnorePeriodSelf(eval_context, rewrite.rhs_const_id, phase);
}
// TODO: Process other requirements.
return info;
}

Expand Down Expand Up @@ -524,7 +557,8 @@ static auto PerformAggregateAccess(EvalContext& eval_context, SemIR::Inst inst)
return eval_context.GetConstantValue(elements[index]);
} else {
CARBON_CHECK(phase != Phase::Template,
"Failed to evaluate template constant {0}", inst);
"Failed to evaluate template constant {0} arg0: {1}", inst,
eval_context.insts().Get(access_inst.aggregate_id));
}
return MakeConstantResult(eval_context.context(), access_inst, phase);
}
Expand Down Expand Up @@ -1461,6 +1495,7 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
Phase phase = Phase::Template;
SemIR::FacetTypeInfo info = GetConstantFacetTypeInfo(
eval_context, facet_type.facet_type_id, &phase);
info.Canonicalize();
// TODO: Reuse `inst` if we can detect that nothing has changed.
return MakeFacetTypeResult(eval_context.context(), info, phase);
}
Expand Down Expand Up @@ -1570,21 +1605,30 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
const auto& bind_name =
eval_context.entity_names().Get(bind.entity_name_id);

// If we know which specific we're evaluating within and this is an
// argument of that specific, its constant value is the corresponding
// argument value.
if (auto value =
eval_context.GetCompileTimeBindValue(bind_name.bind_index);
value.is_valid()) {
return value;
Phase phase;
if (bind_name.name_id == SemIR::NameId::PeriodSelf) {
phase = Phase::PeriodSelfSymbolic;
} else {
// If we know which specific we're evaluating within and this is an
// argument of that specific, its constant value is the corresponding
// argument value.
if (auto value =
eval_context.GetCompileTimeBindValue(bind_name.bind_index);
value.is_valid()) {
return value;
}
phase = Phase::Symbolic;
}

// The constant form of a symbolic binding is an idealized form of the
// original, with no equivalent value.
bind.entity_name_id =
eval_context.entity_names().MakeCanonical(bind.entity_name_id);
bind.value_id = SemIR::InstId::Invalid;
return MakeConstantResult(eval_context.context(), bind, Phase::Symbolic);
if (!ReplaceFieldWithConstantValue(
eval_context, &bind, &SemIR::BindSymbolicName::type_id, &phase)) {
return MakeNonConstantResult(phase);
}
return MakeConstantResult(eval_context.context(), bind, phase);
}

// These semantic wrappers don't change the constant value.
Expand Down Expand Up @@ -1652,8 +1696,7 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
eval_context.insts().Get(typed_inst.period_self_id).type_id();
SemIR::Inst base_facet_inst =
eval_context.GetConstantValueAsInst(base_facet_type_id);
SemIR::FacetTypeInfo info = {.requirement_block_id =
SemIR::InstBlockId::Invalid};
SemIR::FacetTypeInfo info = {.other_requirements = false};
// `where` provides that the base facet is an error, `type`, or a facet
// type.
if (auto facet_type = base_facet_inst.TryAs<SemIR::FacetType>()) {
Expand All @@ -1666,16 +1709,37 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
"Unexpected type_id: {0}, inst: {1}", base_facet_type_id,
base_facet_inst);
}
// TODO: Combine other requirements, and then process & canonicalize them.
info.requirement_block_id = typed_inst.requirements_id;
if (typed_inst.requirements_id.is_valid()) {
auto insts = eval_context.inst_blocks().Get(typed_inst.requirements_id);
for (auto inst_id : insts) {
if (auto rewrite =
eval_context.insts().TryGetAs<SemIR::RequirementRewrite>(
inst_id)) {
SemIR::ConstantId lhs =
eval_context.GetConstantValue(rewrite->lhs_id);
SemIR::ConstantId rhs =
eval_context.GetConstantValue(rewrite->rhs_id);
// `where` requirements using `.Self` should not be considered
// symbolic
UpdatePhaseIgnorePeriodSelf(eval_context, lhs, &phase);
UpdatePhaseIgnorePeriodSelf(eval_context, rhs, &phase);
info.rewrite_constraints.push_back(
{.lhs_const_id = lhs, .rhs_const_id = rhs});
} else {
// TODO: Handle other requirements
info.other_requirements = true;
}
}
}
info.Canonicalize();
return MakeFacetTypeResult(eval_context.context(), info, phase);
}

// `not true` -> `false`, `not false` -> `true`.
// All other uses of unary `not` are non-constant.
case CARBON_KIND(SemIR::UnaryOperatorNot typed_inst): {
auto const_id = eval_context.GetConstantValue(typed_inst.operand_id);
auto phase = GetPhase(const_id);
auto phase = GetPhase(eval_context, const_id);
if (phase == Phase::Template) {
auto value = eval_context.insts().GetAs<SemIR::BoolLiteral>(
eval_context.constant_values().GetInstId(const_id));
Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class RebuildGenericConstantInEvalBlockCallbacks final
context_.insts().Get(inst_id));
return true;
}
if (!const_id.is_symbolic()) {
if (!context_.constant_values().DependsOnGenericParameter(const_id)) {
// This instruction doesn't have a symbolic constant value, so can't
// contain any bindings that need to be substituted.
return true;
Expand All @@ -104,8 +104,12 @@ class RebuildGenericConstantInEvalBlockCallbacks final
// block.
if (auto binding =
context_.insts().TryGetAs<SemIR::BindSymbolicName>(inst_id)) {
inst_id = Rebuild(inst_id, *binding);
return true;
if (context_.entity_names()
.Get(binding->entity_name_id)
.bind_index.is_valid()) {
inst_id = Rebuild(inst_id, *binding);
return true;
}
}

if (auto pattern =
Expand Down
12 changes: 3 additions & 9 deletions toolchain/check/handle_where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,19 @@ auto HandleParseNode(Context& context, Parse::WhereOperandId node_id) -> bool {
// Introduce a name scope so that we can remove the `.Self` entry we are
// adding to name lookup at the end of the `where` expression.
context.scope_stack().Push();
// Create a generic region containing `.Self` and the constraints.
StartGenericDecl(context);
// Introduce `.Self` as a symbolic binding. Its type is the value of the
// expression to the left of `where`, so `MyInterface` in the example above.
// Because there is no equivalent non-symbolic value, we use `Invalid` as
// the `value_id` on the `BindSymbolicName`.
auto entity_name_id = context.entity_names().Add(
{.name_id = SemIR::NameId::PeriodSelf,
.parent_scope_id = context.scope_stack().PeekNameScopeId(),
.bind_index = context.scope_stack().AddCompileTimeBinding()});
// Invalid because this is not the parameter of a generic.
.bind_index = SemIR::CompileTimeBindIndex::Invalid});
auto inst_id =
context.AddInst(SemIR::LocIdAndInst::NoLoc<SemIR::BindSymbolicName>(
{.type_id = self_type_id,
.entity_name_id = entity_name_id,
// Invalid because there is no equivalent non-symbolic value.
.value_id = SemIR::InstId::Invalid}));
context.scope_stack().PushCompileTimeBinding(inst_id);
auto existing =
context.scope_stack().LookupOrAddName(SemIR::NameId::PeriodSelf, inst_id);
// Shouldn't have any names in newly created scope.
Expand Down Expand Up @@ -122,9 +119,6 @@ auto HandleParseNode(Context& /*context*/, Parse::RequirementAndId /*node_id*/)
}

auto HandleParseNode(Context& context, Parse::WhereExprId node_id) -> bool {
// Discard the generic region containing `.Self` and the constraints.
// TODO: Decide if we want to build a `Generic` object for this.
DiscardGenericDecl(context);
// Remove `PeriodSelf` from name lookup, undoing the `Push` done for the
// `WhereOperand`.
context.scope_stack().Pop();
Expand Down
9 changes: 6 additions & 3 deletions toolchain/check/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,14 @@ static auto BuildInterfaceWitness(
}
break;
}
case SemIR::AssociatedConstantDecl::Kind:
case CARBON_KIND(SemIR::AssociatedConstantDecl associated): {
// TODO: Check we have a value for this constant in the constraint.
context.TODO(impl.definition_id,
"impl of interface with associated constant");
context.TODO(
impl.definition_id,
"impl of interface with associated constant " +
context.names().GetFormatted(associated.name_id).str());
return SemIR::InstId::BuiltinErrorInst;
}
default:
CARBON_CHECK(decl_id == SemIR::InstId::BuiltinErrorInst,
"Unexpected kind of associated entity {0}", decl);
Expand Down
Loading

0 comments on commit 33110d0

Please sign in to comment.