Skip to content

Commit

Permalink
Support initialization of specific classes from struct literals (#4320)
Browse files Browse the repository at this point in the history
Add support for initializing types like `GenericClass(i32)` from a
struct literal. A new kind of instruction, `complete_type_witness`, is
added to the class definition to track the object representation type so
that it's visible to the generics machinery. Accesses to the object
representation of a class have all been updated to pass in the class's
`SpecificId` so that the types of the fields of the specific class are
used instead of the types of the fields of the generic class in places
that look at the object representation -- primarily class
initialization.
  • Loading branch information
zygoloid authored Sep 19, 2024
1 parent d3df61e commit 2044366
Show file tree
Hide file tree
Showing 214 changed files with 5,557 additions and 4,223 deletions.
14 changes: 10 additions & 4 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "toolchain/sem_ir/builtin_inst_kind.h"
#include "toolchain/sem_ir/file.h"
#include "toolchain/sem_ir/formatter.h"
#include "toolchain/sem_ir/generic.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/import_ir.h"
#include "toolchain/sem_ir/inst.h"
Expand Down Expand Up @@ -873,7 +874,7 @@ class TypeCompleter {
if (inst.specific_id.is_valid()) {
ResolveSpecificDefinition(context_, inst.specific_id);
}
Push(class_info.object_repr_id);
Push(class_info.GetObjectRepr(context_.sem_ir(), inst.specific_id));
break;
}
case CARBON_KIND(SemIR::ConstType inst): {
Expand Down Expand Up @@ -1051,14 +1052,19 @@ class TypeCompleter {
// The value representation of an adapter is the value representation of
// its adapted type.
if (class_info.adapt_id.is_valid()) {
return GetNestedValueRepr(class_info.object_repr_id);
return GetNestedValueRepr(SemIR::GetTypeInSpecific(
context_.sem_ir(), inst.specific_id,
context_.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id));
}
// Otherwise, the value representation for a class is a pointer to the
// object representation.
// TODO: Support customized value representations for classes.
// TODO: Pick a better value representation when possible.
return MakePointerValueRepr(class_info.object_repr_id,
SemIR::ValueRepr::ObjectAggregate);
return MakePointerValueRepr(
class_info.GetObjectRepr(context_.sem_ir(), inst.specific_id),
SemIR::ValueRepr::ObjectAggregate);
}

template <typename InstT>
Expand Down
35 changes: 21 additions & 14 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
SemIR::InstId value_id,
ConversionTarget target) -> SemIR::InstId {
PendingBlock target_block(context);
auto& class_info = context.classes().Get(dest_type.class_id);
if (class_info.inheritance_kind == SemIR::Class::Abstract) {
auto& dest_class_info = context.classes().Get(dest_type.class_id);
if (dest_class_info.inheritance_kind == SemIR::Class::Abstract) {
CARBON_DIAGNOSTIC(ConstructionOfAbstractClass, Error,
"Cannot construct instance of abstract class. "
"Consider using `partial {0}` instead.",
Expand All @@ -542,11 +542,13 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
target.type_id);
return SemIR::InstId::BuiltinError;
}
if (class_info.object_repr_id == SemIR::TypeId::Error) {
auto object_repr_id =
dest_class_info.GetObjectRepr(context.sem_ir(), dest_type.specific_id);
if (object_repr_id == SemIR::TypeId::Error) {
return SemIR::InstId::BuiltinError;
}
auto dest_struct_type =
context.types().GetAs<SemIR::StructType>(class_info.object_repr_id);
context.types().GetAs<SemIR::StructType>(object_repr_id);

// If we're trying to create a class value, form a temporary for the value to
// point to.
Expand All @@ -571,9 +573,10 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
return result_id;
}

// An inheritance path is a sequence of `BaseDecl`s in order from derived to
// base.
using InheritancePath = llvm::SmallVector<SemIR::InstId>;
// An inheritance path is a sequence of `BaseDecl`s and corresponding base types
// in order from derived to base.
using InheritancePath =
llvm::SmallVector<std::pair<SemIR::InstId, SemIR::TypeId>>;

// Computes the inheritance path from class `derived_id` to class `base_id`.
// Returns nullopt if `derived_id` is not a class derived from `base_id`.
Expand Down Expand Up @@ -602,10 +605,13 @@ static auto ComputeInheritancePath(Context& context, SemIR::TypeId derived_id,
result = std::nullopt;
break;
}
result->push_back(derived_class.base_id);
derived_id = context.insts()
.GetAs<SemIR::BaseDecl>(derived_class.base_id)
.base_type_id;
auto base_decl =
context.insts().GetAs<SemIR::BaseDecl>(derived_class.base_id);
auto base_type_id = SemIR::GetTypeInSpecific(
context.sem_ir(), derived_class_type->specific_id,
base_decl.base_type_id);
result->push_back({derived_class.base_id, base_type_id});
derived_id = base_type_id;
}
return result;
}
Expand All @@ -619,10 +625,10 @@ static auto ConvertDerivedToBase(Context& context, SemIR::LocId loc_id,
value_id = ConvertToValueOrRefExpr(context, value_id);

// Add a series of `.base` accesses.
for (auto base_id : path) {
for (auto [base_id, base_type_id] : path) {
auto base_decl = context.insts().GetAs<SemIR::BaseDecl>(base_id);
value_id = context.AddInst<SemIR::ClassElementAccess>(
loc_id, {.type_id = base_decl.base_type_id,
loc_id, {.type_id = base_type_id,
.base_id = value_id,
.index = base_decl.index});
}
Expand Down Expand Up @@ -677,7 +683,8 @@ static auto GetCompatibleBaseType(Context& context, SemIR::TypeId type_id)
if (auto class_type = context.types().TryGetAs<SemIR::ClassType>(type_id)) {
auto& class_info = context.classes().Get(class_type->class_id);
if (class_info.adapt_id.is_valid()) {
return class_info.object_repr_id;
return class_info.GetObjectRepr(context.sem_ir(),
class_type->specific_id);
}
}

Expand Down
3 changes: 3 additions & 0 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,9 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
case SemIR::ClassType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::ClassType::specific_id);
case SemIR::CompleteTypeWitness::Kind:
return RebuildIfFieldsAreConstant(
eval_context, inst, &SemIR::CompleteTypeWitness::object_repr_id);
case SemIR::FunctionType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::FunctionType::specific_id);
Expand Down
120 changes: 77 additions & 43 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static auto MergeClassRedecl(Context& context, SemIRLoc new_loc,
prev_class.body_block_id = new_class.body_block_id;
prev_class.adapt_id = new_class.adapt_id;
prev_class.base_id = new_class.base_id;
prev_class.object_repr_id = new_class.object_repr_id;
prev_class.complete_type_witness_id = new_class.complete_type_witness_id;
}

if ((prev_import_ir_id.is_valid() && !new_is_import) ||
Expand Down Expand Up @@ -561,55 +561,89 @@ auto HandleParseNode(Context& context, Parse::BaseDeclId node_id) -> bool {
return true;
}

auto HandleParseNode(Context& context, Parse::ClassDefinitionId /*node_id*/)
// Checks that the specified finished adapter definition is valid and builds and
// returns a corresponding complete type witness instruction.
static auto CheckCompleteAdapterClassType(Context& context,
Parse::NodeId node_id,
SemIR::ClassId class_id,
SemIR::InstBlockId fields_id)
-> SemIR::InstId {
const auto& class_info = context.classes().Get(class_id);
if (class_info.base_id.is_valid()) {
CARBON_DIAGNOSTIC(AdaptWithBase, Error,
"Adapter cannot have a base class.");
CARBON_DIAGNOSTIC(AdaptBaseHere, Note, "`base` declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithBase)
.Note(class_info.base_id, AdaptBaseHere)
.Emit();
return SemIR::InstId::BuiltinError;
}

if (!context.inst_blocks().Get(fields_id).empty()) {
auto first_field_id = context.inst_blocks().Get(fields_id).front();
CARBON_DIAGNOSTIC(AdaptWithFields, Error, "Adapter cannot have fields.");
CARBON_DIAGNOSTIC(AdaptFieldHere, Note, "First field declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithFields)
.Note(first_field_id, AdaptFieldHere)
.Emit();
return SemIR::InstId::BuiltinError;
}

// The object representation of the adapter is the object representation
// of the adapted type. This is the adapted type itself unless it's a class
// type.
//
// TODO: The object representation of `const T` should also be the object
// representation of `T`.
auto adapted_type_id = context.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id;
if (auto adapted_class =
context.types().TryGetAs<SemIR::ClassType>(adapted_type_id)) {
auto& adapted_class_info = context.classes().Get(adapted_class->class_id);
if (adapted_class_info.adapt_id.is_valid()) {
return adapted_class_info.complete_type_witness_id;
}
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = adapted_type_id});
}

// Checks that the specified finished class definition is valid and builds and
// returns a corresponding complete type witness instruction.
static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
SemIR::ClassId class_id,
SemIR::InstBlockId fields_id)
-> SemIR::InstId {
auto& class_info = context.classes().Get(class_id);
if (class_info.adapt_id.is_valid()) {
return CheckCompleteAdapterClassType(context, node_id, class_id, fields_id);
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = context.GetStructType(fields_id)});
}

auto HandleParseNode(Context& context, Parse::ClassDefinitionId node_id)
-> bool {
auto fields_id = context.args_type_info_stack().Pop();
auto class_id =
context.node_stack().Pop<Parse::NodeKind::ClassDefinitionStart>();
context.inst_block_stack().Pop();

// The class type is now fully defined. Compute its object representation.
auto complete_type_witness_id =
CheckCompleteClassType(context, node_id, class_id, fields_id);
auto& class_info = context.classes().Get(class_id);
if (class_info.adapt_id.is_valid()) {
class_info.object_repr_id = SemIR::TypeId::Error;
if (class_info.base_id.is_valid()) {
CARBON_DIAGNOSTIC(AdaptWithBase, Error,
"Adapter cannot have a base class.");
CARBON_DIAGNOSTIC(AdaptBaseHere, Note, "`base` declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithBase)
.Note(class_info.base_id, AdaptBaseHere)
.Emit();
} else if (!context.inst_blocks().Get(fields_id).empty()) {
auto first_field_id = context.inst_blocks().Get(fields_id).front();
CARBON_DIAGNOSTIC(AdaptWithFields, Error, "Adapter cannot have fields.");
CARBON_DIAGNOSTIC(AdaptFieldHere, Note,
"First field declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithFields)
.Note(first_field_id, AdaptFieldHere)
.Emit();
} else {
// The object representation of the adapter is the object representation
// of the adapted type.
auto adapted_type_id = context.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id;
// If we adapt an adapter, directly track the non-adapter type we're
// adapting so that we have constant-time access to it.
if (auto adapted_class =
context.types().TryGetAs<SemIR::ClassType>(adapted_type_id)) {
auto& adapted_class_info =
context.classes().Get(adapted_class->class_id);
if (adapted_class_info.adapt_id.is_valid()) {
adapted_type_id = adapted_class_info.object_repr_id;
}
}
class_info.object_repr_id = adapted_type_id;
}
} else {
class_info.object_repr_id = context.GetStructType(fields_id);
}
class_info.complete_type_witness_id = complete_type_witness_id;

context.inst_block_stack().Pop();

FinishGenericDefinition(context, class_info.generic_id);

Expand Down
34 changes: 26 additions & 8 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "toolchain/sem_ir/import_ir.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/inst_kind.h"
#include "toolchain/sem_ir/type_info.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {
Expand Down Expand Up @@ -1011,6 +1012,9 @@ class ImportRefResolver {
case CARBON_KIND(SemIR::ClassType inst): {
return TryResolveTypedInst(inst);
}
case CARBON_KIND(SemIR::CompleteTypeWitness inst): {
return TryResolveTypedInst(inst);
}
case CARBON_KIND(SemIR::ConstType inst): {
return TryResolveTypedInst(inst);
}
Expand Down Expand Up @@ -1228,12 +1232,11 @@ class ImportRefResolver {
// Fills out the class definition for an incomplete class.
auto AddClassDefinition(const SemIR::Class& import_class,
SemIR::Class& new_class,
SemIR::ConstantId object_repr_const_id,
SemIR::InstId complete_type_witness_id,
SemIR::InstId base_id) -> void {
new_class.definition_id = new_class.first_owning_decl_id;

new_class.object_repr_id =
context_.GetTypeIdForTypeConstant(object_repr_const_id);
new_class.complete_type_witness_id = complete_type_witness_id;

new_class.scope_id = context_.name_scopes().Add(
new_class.first_owning_decl_id, SemIR::NameId::Invalid,
Expand Down Expand Up @@ -1312,10 +1315,10 @@ class ImportRefResolver {
auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id);
auto generic_data = GetLocalGenericData(import_class.generic_id);
auto self_const_id = GetLocalConstantId(import_class.self_type_id);
auto object_repr_const_id =
import_class.object_repr_id.is_valid()
? GetLocalConstantId(import_class.object_repr_id)
: SemIR::ConstantId::Invalid;
auto complete_type_witness_id =
import_class.complete_type_witness_id.is_valid()
? GetLocalConstantInstId(import_class.complete_type_witness_id)
: SemIR::InstId::Invalid;
auto base_id = import_class.base_id.is_valid()
? GetLocalConstantInstId(import_class.base_id)
: SemIR::InstId::Invalid;
Expand All @@ -1334,7 +1337,7 @@ class ImportRefResolver {
new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id);

if (import_class.is_defined()) {
AddClassDefinition(import_class, new_class, object_repr_const_id,
AddClassDefinition(import_class, new_class, complete_type_witness_id,
base_id);
}

Expand Down Expand Up @@ -1368,6 +1371,21 @@ class ImportRefResolver {
}
}

auto TryResolveTypedInst(SemIR::CompleteTypeWitness inst) -> ResolveResult {
CARBON_CHECK(import_ir_.types().GetInstId(inst.type_id) ==
SemIR::InstId::BuiltinWitnessType);
auto object_repr_const_id = GetLocalConstantId(inst.object_repr_id);
if (HasNewWork()) {
return Retry();
}
auto object_repr_id =
context_.GetTypeIdForTypeConstant(object_repr_const_id);
return ResolveAs<SemIR::CompleteTypeWitness>(
{.type_id =
context_.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = object_repr_id});
}

auto TryResolveTypedInst(SemIR::ConstType inst) -> ResolveResult {
CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
auto inner_const_id = GetLocalConstantId(inst.inner_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ let d: c = {};
// CHECK:STDOUT: constants {
// CHECK:STDOUT: %C: type = class_type @C [template]
// CHECK:STDOUT: %.1: type = struct_type {} [template]
// CHECK:STDOUT: %.2: type = tuple_type () [template]
// CHECK:STDOUT: %.3: type = ptr_type %.1 [template]
// CHECK:STDOUT: %.2: <witness> = complete_type_witness %.1 [template]
// CHECK:STDOUT: %.3: type = tuple_type () [template]
// CHECK:STDOUT: %.4: type = ptr_type %.1 [template]
// CHECK:STDOUT: %struct: %C = struct_value () [template]
// CHECK:STDOUT: }
// CHECK:STDOUT:
Expand All @@ -43,6 +44,8 @@ let d: c = {};
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %.loc11: <witness> = complete_type_witness %.1 [template = constants.%.2]
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
Expand Down
Loading

0 comments on commit 2044366

Please sign in to comment.