Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 238 additions & 28 deletions crates/wit-component/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,32 @@ fn to_val_type(ty: &WasmType) -> ValType {
}
}

fn import_func_name(f: &Function) -> String {
match f.kind {
FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {
format!("import-func-{}", f.item_name())
}

// transform `[method]foo.bar` into `import-method-foo-bar` to
// have it be a valid kebab-name which can't conflict with
// anything else.
//
// There's probably a better and more "formal" way to do this
// but quick-and-dirty string manipulation should work well
// enough for now hopefully.
FunctionKind::Method(_)
| FunctionKind::AsyncMethod(_)
| FunctionKind::Static(_)
| FunctionKind::AsyncStatic(_)
| FunctionKind::Constructor(_) => {
format!(
"import-{}",
f.name.replace('[', "").replace([']', '.', ' '], "-")
)
}
}
}

bitflags::bitflags! {
/// Options in the `canon lower` or `canon lift` required for a particular
/// function.
Expand Down Expand Up @@ -394,6 +420,10 @@ pub struct EncodingState<'a> {

/// Metadata about the world inferred from the input to `ComponentEncoder`.
info: &'a ComponentWorld<'a>,

/// Maps from original export name to task initialization wrapper function index.
/// Used to wrap exports with __wasm_init_(async_)task calls.
export_task_initialization_wrappers: HashMap<String, u32>,
}

impl<'a> EncodingState<'a> {
Expand Down Expand Up @@ -600,6 +630,10 @@ impl<'a> EncodingState<'a> {
// at the end.
self.instantiate_main_module(&shims)?;

// Create any wrappers needed for initializing tasks if task initialization
// exports are present in the main module.
self.create_export_task_initialization_wrappers()?;

// Separate the adapters according which should be instantiated before
// and after indirect lowerings are encoded.
let (before, after) = self
Expand Down Expand Up @@ -688,7 +722,9 @@ impl<'a> EncodingState<'a> {
| Export::GeneralPurposeImportRealloc
| Export::Initialize
| Export::ReallocForAdapter
| Export::IndirectFunctionTable => continue,
| Export::IndirectFunctionTable
| Export::WasmInitTask
| Export::WasmInitAsyncTask => continue,
}
}

Expand Down Expand Up @@ -1014,32 +1050,6 @@ impl<'a> EncodingState<'a> {
name
}
}

fn import_func_name(f: &Function) -> String {
match f.kind {
FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {
format!("import-func-{}", f.item_name())
}

// transform `[method]foo.bar` into `import-method-foo-bar` to
// have it be a valid kebab-name which can't conflict with
// anything else.
//
// There's probably a better and more "formal" way to do this
// but quick-and-dirty string manipulation should work well
// enough for now hopefully.
FunctionKind::Method(_)
| FunctionKind::AsyncMethod(_)
| FunctionKind::Static(_)
| FunctionKind::AsyncStatic(_)
| FunctionKind::Constructor(_) => {
format!(
"import-{}",
f.name.replace('[', "").replace([']', '.', ' '], "-")
)
}
}
}
}

fn encode_lift(
Expand All @@ -1053,8 +1063,14 @@ impl<'a> EncodingState<'a> {
let resolve = &self.info.encoder.metadata.resolve;
let metadata = self.info.module_metadata_for(module);
let instance_index = self.instance_for(module);
// If we generated an init task wrapper for this export, use that,
// otherwise alias the original export.
let core_func_index =
self.core_alias_export(Some(core_name), instance_index, core_name, ExportKind::Func);
if let Some(&wrapper_idx) = self.export_task_initialization_wrappers.get(core_name) {
wrapper_idx
} else {
self.core_alias_export(Some(core_name), instance_index, core_name, ExportKind::Func)
};
let exports = self.info.exports_for(module);

let options = RequiredOptions::for_export(
Expand Down Expand Up @@ -2189,6 +2205,199 @@ impl<'a> EncodingState<'a> {
.core_alias_export(debug_name, instance, name, kind)
})
}

/// Modules may define `__wasm_init_(async_)task` functions that must be called
/// at the start of every exported function to set up the stack pointer and
/// thread-local storage. To achieve this, we create a wrapper module called
/// `task-init-wrappers` that imports the original exports and the
/// task initialization functions, and defines wrapper functions that call
/// the relevant task initialization function before delegating to the original export.
/// We then instantiate this wrapper module and use its exports as the final
/// exports of the component. If we don't find a `__wasm_init_task` export,
/// we elide the wrapper module entirely.
fn create_export_task_initialization_wrappers(&mut self) -> Result<()> {
let instance_index = self.instance_index.unwrap();
let resolve = &self.info.encoder.metadata.resolve;
let world = &resolve.worlds[self.info.encoder.metadata.world];
let exports = self.info.exports_for(CustomModule::Main);

let wasm_init_task_export = exports.wasm_init_task();
let wasm_init_async_task_export = exports.wasm_init_async_task();
if wasm_init_task_export.is_none() || wasm_init_async_task_export.is_none() {
// __wasm_init_(async_)task was not exported by the main module,
// so no wrappers are needed.
return Ok(());
}
let wasm_init_task = wasm_init_task_export.unwrap();
let wasm_init_async_task = wasm_init_async_task_export.unwrap();

// Collect the exports that we will need to wrap, alongside information
// that we'll need to build the wrappers.
let funcs_to_wrap: Vec<_> = exports
.iter()
.flat_map(|(core_name, export)| match export {
Export::WorldFunc(key, _, abi) => match &world.exports[key] {
WorldItem::Function(f) => Some((core_name, f, abi)),
_ => None,
},
Export::InterfaceFunc(_, id, func_name, abi) => {
let func = &resolve.interfaces[*id].functions[func_name.as_str()];
Some((core_name, func, abi))
}
_ => None,
})
.collect();

if funcs_to_wrap.is_empty() {
// No exports, so no wrappers are needed.
return Ok(());
}

// Now we build the wrapper module
let mut types = TypeSection::new();
let mut imports = ImportSection::new();
let mut functions = FunctionSection::new();
let mut exports_section = ExportSection::new();
let mut code = CodeSection::new();

// Type for __wasm_init_(async_)task: () -> ()
types.ty().function([], []);
let wasm_init_task_type_idx = 0;

// Import __wasm_init_task and __wasm_init_async_task into the wrapper module
imports.import(
"",
wasm_init_task,
EntityType::Function(wasm_init_task_type_idx),
);
imports.import(
"",
wasm_init_async_task,
EntityType::Function(wasm_init_task_type_idx),
);
let wasm_init_task_func_idx = 0u32;
let wasm_init_async_task_func_idx = 1u32;

let mut type_indices = HashMap::new();
let mut next_type_idx = 1u32;
let mut next_func_idx = 2u32;

// First pass: create all types and import all original functions
struct FuncInfo<'a> {
name: &'a str,
type_idx: u32,
orig_func_idx: u32,
is_async: bool,
n_params: usize,
}
let mut func_info = Vec::new();
for &(name, func, abi) in funcs_to_wrap.iter() {
let sig = resolve.wasm_signature(*abi, func);
let type_idx = *type_indices.entry(sig.clone()).or_insert_with(|| {
let idx = next_type_idx;
types.ty().function(
sig.params.iter().map(to_val_type),
sig.results.iter().map(to_val_type),
);
next_type_idx += 1;
idx
});

imports.import("", &import_func_name(func), EntityType::Function(type_idx));
let orig_func_idx = next_func_idx;
next_func_idx += 1;

func_info.push(FuncInfo {
name,
type_idx,
orig_func_idx,
is_async: abi.is_async(),
n_params: sig.params.len(),
});
}

// Second pass: define wrapper functions
for info in func_info.iter() {
let wrapper_func_idx = next_func_idx;
functions.function(info.type_idx);

let mut func = wasm_encoder::Function::new([]);
if info.is_async {
func.instruction(&Instruction::Call(wasm_init_async_task_func_idx));
} else {
func.instruction(&Instruction::Call(wasm_init_task_func_idx));
}
for i in 0..info.n_params as u32 {
func.instruction(&Instruction::LocalGet(i));
}
func.instruction(&Instruction::Call(info.orig_func_idx));
func.instruction(&Instruction::End);
code.function(&func);

exports_section.export(info.name, ExportKind::Func, wrapper_func_idx);
next_func_idx += 1;
}

let mut wrapper_module = Module::new();
wrapper_module.section(&types);
wrapper_module.section(&imports);
wrapper_module.section(&functions);
wrapper_module.section(&exports_section);
wrapper_module.section(&code);

let wrapper_module_idx = self
.component
.core_module(Some("init-task-wrappers"), &wrapper_module);

// Prepare imports for instantiating the wrapper module
let mut wrapper_imports = Vec::new();
let init_idx = self.core_alias_export(
Some(wasm_init_task),
instance_index,
wasm_init_task,
ExportKind::Func,
);
let init_async_idx = self.core_alias_export(
Some(wasm_init_async_task),
instance_index,
wasm_init_async_task,
ExportKind::Func,
);
wrapper_imports.push((wasm_init_task.into(), ExportKind::Func, init_idx));
wrapper_imports.push((
wasm_init_async_task.into(),
ExportKind::Func,
init_async_idx,
));

// Import all original exports to be wrapped
for (name, func, _) in &funcs_to_wrap {
let orig_idx =
self.core_alias_export(Some(name), instance_index, name, ExportKind::Func);
wrapper_imports.push((import_func_name(func), ExportKind::Func, orig_idx));
}

let wrapper_args_idx = self.component.core_instantiate_exports(
Some("init-task-wrappers-args"),
wrapper_imports.iter().map(|(n, k, i)| (n.as_str(), *k, *i)),
);

let wrapper_instance = self.component.core_instantiate(
Some("init-task-wrappers-instance"),
wrapper_module_idx,
[("", ModuleArg::Instance(wrapper_args_idx))],
);

// Map original names to wrapper indices
for (name, _, _) in funcs_to_wrap {
let wrapper_idx =
self.core_alias_export(Some(&name), wrapper_instance, &name, ExportKind::Func);
self.export_task_initialization_wrappers
.insert(name.into(), wrapper_idx);
}

Ok(())
}
}

/// A list of "shims" which start out during the component instantiation process
Expand Down Expand Up @@ -3123,6 +3332,7 @@ impl ComponentEncoder {
exported_instances: Default::default(),
aliased_core_items: Default::default(),
info: &world,
export_task_initialization_wrappers: HashMap::new(),
};
state.encode_imports(&self.import_name_map)?;
state.encode_core_modules();
Expand Down
Loading