diff --git a/crates/wasmi/benches/bench/mod.rs b/crates/wasmi/benches/bench/mod.rs index b47df4bc45..9f9bf33a13 100644 --- a/crates/wasmi/benches/bench/mod.rs +++ b/crates/wasmi/benches/bench/mod.rs @@ -23,6 +23,7 @@ pub fn load_wasm_from_file(file_name: &str) -> Vec { /// Returns a [`Config`] useful for benchmarking. fn bench_config() -> Config { let mut config = Config::default(); + config.wasm_tail_call(true); config.set_stack_limits(StackLimits::new(1024, 1024 * 1024, 64 * 1024).unwrap()); config } diff --git a/crates/wasmi/benches/wat/fibonacci.wat b/crates/wasmi/benches/wat/fibonacci.wat index eb65fc5a58..4f095caa87 100644 --- a/crates/wasmi/benches/wat/fibonacci.wat +++ b/crates/wasmi/benches/wat/fibonacci.wat @@ -27,7 +27,7 @@ (return (local.get $b)) ) ) - (call $fib_tail_recursive + (return_call $fib_tail_recursive (i64.sub (local.get $N) (i64.const 1)) (local.get $b) (i64.add (local.get $a) (local.get $b)) @@ -35,7 +35,7 @@ ) (func (export "fibonacci_tail") (param $N i64) (result i64) - (call $fib_tail_recursive (local.get $N) (i64.const 0) (i64.const 1)) + (return_call $fib_tail_recursive (local.get $N) (i64.const 0) (i64.const 1)) ) (func $fib_iterative (export "fibonacci_iter") (param $N i64) (result i64) diff --git a/crates/wasmi/src/engine/bytecode/mod.rs b/crates/wasmi/src/engine/bytecode/mod.rs index 17a111197d..858409d001 100644 --- a/crates/wasmi/src/engine/bytecode/mod.rs +++ b/crates/wasmi/src/engine/bytecode/mod.rs @@ -53,6 +53,15 @@ pub enum Instruction { }, Return(DropKeep), ReturnIfNez(DropKeep), + ReturnCall { + drop_keep: DropKeep, + func: FuncIdx, + }, + ReturnCallIndirect { + drop_keep: DropKeep, + table: TableIdx, + func_type: SignatureIdx, + }, Call(FuncIdx), CallIndirect { table: TableIdx, diff --git a/crates/wasmi/src/engine/config.rs b/crates/wasmi/src/engine/config.rs index 0bfdd7416f..44a4a811c1 100644 --- a/crates/wasmi/src/engine/config.rs +++ b/crates/wasmi/src/engine/config.rs @@ -25,6 +25,8 @@ pub struct Config { bulk_memory: bool, /// Is `true` if the [`reference-types`] Wasm proposal is enabled. reference_types: bool, + /// Is `true` if the [`tail-call`] Wasm proposal is enabled. + tail_call: bool, /// Is `true` if Wasm instructions on `f32` and `f64` types are allowed. floats: bool, /// Is `true` if `wasmi` executions shall consume fuel. @@ -94,6 +96,7 @@ impl Default for Config { multi_value: true, bulk_memory: true, reference_types: true, + tail_call: false, floats: true, consume_fuel: false, fuel_costs: FuelCosts::default(), @@ -201,6 +204,18 @@ impl Config { self } + /// Enable or disable the [`tail-call`] Wasm proposal for the [`Config`]. + /// + /// # Note + /// + /// Disabled by default. + /// + /// [`tail-call`]: https://github.com/WebAssembly/tail-calls + pub fn wasm_tail_call(&mut self, enable: bool) -> &mut Self { + self.tail_call = enable; + self + } + /// Enable or disable Wasm floating point (`f32` and `f64`) instructions and types. /// /// Enabled by default. @@ -252,12 +267,12 @@ impl Config { sign_extension: self.sign_extension, bulk_memory: self.bulk_memory, reference_types: self.reference_types, + tail_call: self.tail_call, floats: self.floats, component_model: false, simd: false, relaxed_simd: false, threads: false, - tail_call: false, multi_memory: false, exceptions: false, memory64: false, diff --git a/crates/wasmi/src/engine/executor.rs b/crates/wasmi/src/engine/executor.rs index 4052e5f470..3b95c5dde3 100644 --- a/crates/wasmi/src/engine/executor.rs +++ b/crates/wasmi/src/engine/executor.rs @@ -25,6 +25,7 @@ use crate::{ table::TableEntity, Func, FuncRef, + Instance, StoreInner, Table, }; @@ -42,7 +43,7 @@ pub enum WasmOutcome { /// The Wasm execution has ended and returns to the host side. Return, /// The Wasm execution calls a host function. - Call(Func), + Call { host_func: Func, instance: Instance }, } /// The outcome of a Wasm execution. @@ -56,7 +57,16 @@ pub enum CallOutcome { /// The Wasm execution continues in Wasm. Continue, /// The Wasm execution calls a host function. - Call(Func), + Call { host_func: Func, instance: Instance }, +} + +/// The kind of a function call. +#[derive(Debug, Copy, Clone)] +pub enum CallKind { + /// A nested function call. + Nested, + /// A tailing function call. + Tail, } /// The outcome of a Wasm return statement. @@ -203,16 +213,56 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { return Ok(WasmOutcome::Return); } } + Instr::ReturnCall { drop_keep, func } => { + if let CallOutcome::Call { + host_func, + instance, + } = self.visit_return_call(drop_keep, func)? + { + return Ok(WasmOutcome::Call { + host_func, + instance, + }); + } + } + Instr::ReturnCallIndirect { + drop_keep, + table, + func_type, + } => { + if let CallOutcome::Call { + host_func, + instance, + } = self.visit_return_call_indirect(drop_keep, table, func_type)? + { + return Ok(WasmOutcome::Call { + host_func, + instance, + }); + } + } Instr::Call(func) => { - if let CallOutcome::Call(host_func) = self.visit_call(func)? { - return Ok(WasmOutcome::Call(host_func)); + if let CallOutcome::Call { + host_func, + instance, + } = self.visit_call(func)? + { + return Ok(WasmOutcome::Call { + host_func, + instance, + }); } } Instr::CallIndirect { table, func_type } => { - if let CallOutcome::Call(host_func) = - self.visit_call_indirect(table, func_type)? + if let CallOutcome::Call { + host_func, + instance, + } = self.visit_call_indirect(table, func_type)? { - return Ok(WasmOutcome::Call(host_func)); + return Ok(WasmOutcome::Call { + host_func, + instance, + }); } } Instr::Drop => self.visit_drop(), @@ -519,24 +569,30 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { /// the function call so that the stack and execution state is synchronized /// with the outer structures. #[inline(always)] - fn call_func(&mut self, func: &Func) -> Result { + fn call_func(&mut self, func: &Func, kind: CallKind) -> Result { self.next_instr(); self.sync_stack_ptr(); - self.call_stack - .push(FuncFrame::new(self.ip, self.cache.instance()))?; - let wasm_func = match self.ctx.resolve_func(func) { - FuncEntity::Wasm(wasm_func) => wasm_func, + if matches!(kind, CallKind::Nested) { + self.call_stack + .push(FuncFrame::new(self.ip, self.cache.instance()))?; + } + match self.ctx.resolve_func(func) { + FuncEntity::Wasm(wasm_func) => { + let header = self.code_map.header(wasm_func.func_body()); + self.value_stack.prepare_wasm_call(header)?; + self.sp = self.value_stack.stack_ptr(); + self.cache.update_instance(wasm_func.instance()); + self.ip = self.code_map.instr_ptr(header.iref()); + Ok(CallOutcome::Continue) + } FuncEntity::Host(_host_func) => { self.cache.reset(); - return Ok(CallOutcome::Call(*func)); + Ok(CallOutcome::Call { + host_func: *func, + instance: *self.cache.instance(), + }) } - }; - let header = self.code_map.header(wasm_func.func_body()); - self.value_stack.prepare_wasm_call(header)?; - self.sp = self.value_stack.stack_ptr(); - self.cache.update_instance(wasm_func.instance()); - self.ip = self.code_map.instr_ptr(header.iref()); - Ok(CallOutcome::Continue) + } } /// Returns to the caller. @@ -608,6 +664,48 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { fn fuel_costs(&self) -> &FuelCosts { self.ctx.engine().config().fuel_costs() } + + /// Executes a `call` or `return_call` instruction. + #[inline(always)] + fn execute_call( + &mut self, + func_index: FuncIdx, + kind: CallKind, + ) -> Result { + let callee = self.cache.get_func(self.ctx, func_index); + self.call_func(&callee, kind) + } + + /// Executes a `call_indirect` or `return_call_indirect` instruction. + #[inline(always)] + fn execute_call_indirect( + &mut self, + table: TableIdx, + func_index: u32, + func_type: SignatureIdx, + kind: CallKind, + ) -> Result { + let table = self.cache.get_table(self.ctx, table); + let funcref = self + .ctx + .resolve_table(&table) + .get_untyped(func_index) + .map(FuncRef::from) + .ok_or(TrapCode::TableOutOfBounds)?; + let func = funcref.func().ok_or(TrapCode::IndirectCallToNull)?; + let actual_signature = self.ctx.resolve_func(func).ty_dedup(); + let expected_signature = self + .ctx + .resolve_instance(self.cache.instance()) + .get_signature(func_type.into_inner()) + .unwrap_or_else(|| { + panic!("missing signature for call_indirect at index: {func_type:?}") + }); + if actual_signature != expected_signature { + return Err(TrapCode::BadSignature).map_err(Into::into); + } + self.call_func(func, kind) + } } impl<'ctx, 'engine> Executor<'ctx, 'engine> { @@ -712,10 +810,32 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { self.next_instr() } + #[inline(always)] + fn visit_return_call( + &mut self, + drop_keep: DropKeep, + func_index: FuncIdx, + ) -> Result { + self.sp.drop_keep(drop_keep); + self.execute_call(func_index, CallKind::Tail) + } + + #[inline(always)] + fn visit_return_call_indirect( + &mut self, + drop_keep: DropKeep, + table: TableIdx, + func_type: SignatureIdx, + ) -> Result { + let func_index: u32 = self.sp.pop_as(); + self.sp.drop_keep(drop_keep); + self.execute_call_indirect(table, func_index, func_type, CallKind::Tail) + } + #[inline(always)] fn visit_call(&mut self, func_index: FuncIdx) -> Result { let callee = self.cache.get_func(self.ctx, func_index); - self.call_func(&callee) + self.call_func(&callee, CallKind::Nested) } #[inline(always)] @@ -725,26 +845,7 @@ impl<'ctx, 'engine> Executor<'ctx, 'engine> { func_type: SignatureIdx, ) -> Result { let func_index: u32 = self.sp.pop_as(); - let table = self.cache.get_table(self.ctx, table); - let funcref = self - .ctx - .resolve_table(&table) - .get_untyped(func_index) - .map(FuncRef::from) - .ok_or(TrapCode::TableOutOfBounds)?; - let func = funcref.func().ok_or(TrapCode::IndirectCallToNull)?; - let actual_signature = self.ctx.resolve_func(func).ty_dedup(); - let expected_signature = self - .ctx - .resolve_instance(self.cache.instance()) - .get_signature(func_type.into_inner()) - .unwrap_or_else(|| { - panic!("missing signature for call_indirect at index: {func_type:?}") - }); - if actual_signature != expected_signature { - return Err(TrapCode::BadSignature).map_err(Into::into); - } - self.call_func(func) + self.execute_call_indirect(table, func_index, func_type, CallKind::Nested) } #[inline(always)] diff --git a/crates/wasmi/src/engine/func_builder/mod.rs b/crates/wasmi/src/engine/func_builder/mod.rs index f05292360b..1c161ee43c 100644 --- a/crates/wasmi/src/engine/func_builder/mod.rs +++ b/crates/wasmi/src/engine/func_builder/mod.rs @@ -135,6 +135,9 @@ macro_rules! impl_visit_operator { ( @reference_types $($rest:tt)* ) => { impl_visit_operator!(@@supported $($rest)*); }; + ( @tail_call $($rest:tt)* ) => { + impl_visit_operator!(@@supported $($rest)*); + }; ( @@supported $op:ident $({ $($arg:ident: $argty:ty),* })? => $visit:ident $($rest:tt)* ) => { fn $visit(&mut self $($(,$arg: $argty)*)?) -> Self::Output { let offset = self.current_pos(); diff --git a/crates/wasmi/src/engine/func_builder/translator.rs b/crates/wasmi/src/engine/func_builder/translator.rs index e1e24aeadb..9a52d8f5a3 100644 --- a/crates/wasmi/src/engine/func_builder/translator.rs +++ b/crates/wasmi/src/engine/func_builder/translator.rs @@ -274,6 +274,24 @@ impl<'parser> FuncTranslator<'parser> { Ok(()) } + /// Return the value stack height difference to the height at the given `depth`. + /// + /// # Panics + /// + /// - If the current code is unreachable. + fn height_diff(&self, depth: u32) -> u32 { + debug_assert!(self.is_reachable()); + let current_height = self.stack_height.height(); + let frame = self.alloc.control_frames.nth_back(depth); + let origin_height = frame.stack_height().expect("frame is reachable"); + assert!( + origin_height <= current_height, + "encountered value stack underflow: \ + current height {current_height}, original height {origin_height}", + ); + current_height - origin_height + } + /// Computes how many values should be dropped and kept for the specific branch. /// /// # Panics @@ -290,14 +308,7 @@ impl<'parser> FuncTranslator<'parser> { ControlFrameKind::Loop => frame.block_type().len_params(self.res.engine()), }; // Find out how many values we need to drop. - let current_height = self.stack_height.height(); - let origin_height = frame.stack_height().expect("frame is reachable"); - assert!( - origin_height <= current_height, - "encountered value stack underflow: \ - current height {current_height}, original height {origin_height}", - ); - let height_diff = current_height - origin_height; + let height_diff = self.height_diff(depth); assert!( keep <= height_diff, "tried to keep {keep} values while having \ @@ -307,6 +318,15 @@ impl<'parser> FuncTranslator<'parser> { DropKeep::new(drop as usize, keep as usize).map_err(Into::into) } + /// Returns the maximum control stack depth at the current position in the code. + fn max_depth(&self) -> u32 { + self.alloc + .control_frames + .len() + .checked_sub(1) + .expect("control flow frame stack must not be empty") as u32 + } + /// Compute [`DropKeep`] for the return statement. /// /// # Panics @@ -319,12 +339,7 @@ impl<'parser> FuncTranslator<'parser> { !self.alloc.control_frames.is_empty(), "drop_keep_return cannot be called with the frame stack empty" ); - let max_depth = self - .alloc - .control_frames - .len() - .checked_sub(1) - .expect("control flow frame stack must not be empty") as u32; + let max_depth = self.max_depth(); let drop_keep = self.compute_drop_keep(max_depth)?; let len_params_locals = self.locals.len_registered() as usize; DropKeep::new( @@ -694,6 +709,29 @@ impl<'parser> FuncTranslator<'parser> { fn unsupported_operator(&self, name: &str) -> Result<(), TranslationError> { panic!("tried to translate an unsupported Wasm operator: {name}") } + + /// Computes how many values should be dropped and kept for the return call. + /// + /// # Panics + /// + /// If underflow of the value stack is detected. + fn drop_keep_return_call(&self, callee_type: &FuncType) -> Result { + debug_assert!(self.is_reachable()); + // For return calls we need to adjust the `keep` value to + // be equal to the amount of parameters the callee expects. + let keep = callee_type.params().len() as u32; + // Find out how many values we need to drop. + let max_depth = self.max_depth(); + let height_diff = self.height_diff(max_depth); + assert!( + keep <= height_diff, + "tried to keep {keep} values while having \ + only {height_diff} values available on the frame", + ); + let len_params_locals = self.locals.len_registered(); + let drop = height_diff - keep + len_params_locals; + DropKeep::new(drop as usize, keep as usize).map_err(Into::into) + } } /// An acquired target. @@ -729,6 +767,9 @@ macro_rules! impl_visit_operator { ( @reference_types $($rest:tt)* ) => { impl_visit_operator!(@@skipped $($rest)*); }; + ( @tail_call $($rest:tt)* ) => { + impl_visit_operator!(@@skipped $($rest)*); + }; ( @@skipped $op:ident $({ $($arg:ident: $argty:ty),* })? => $visit:ident $($rest:tt)* ) => { // We skip Wasm operators that we already implement manually. impl_visit_operator!($($rest)*); @@ -1087,6 +1128,48 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { }) } + fn visit_return_call(&mut self, func_idx: u32) -> Result<(), TranslationError> { + self.translate_if_reachable(|builder| { + let func = bytecode::FuncIdx::from(func_idx); + let func_type = builder.func_type_of(func_idx.into()); + let drop_keep = builder.drop_keep_return_call(&func_type)?; + builder.bump_fuel_consumption(builder.fuel_costs().call); + builder.bump_fuel_consumption(drop_keep.fuel_consumption(builder.fuel_costs())); + builder + .alloc + .inst_builder + .push_inst(Instruction::ReturnCall { drop_keep, func }); + builder.reachable = false; + Ok(()) + }) + } + + fn visit_return_call_indirect( + &mut self, + func_type_index: u32, + table_index: u32, + ) -> Result<(), TranslationError> { + self.translate_if_reachable(|builder| { + let signature = SignatureIdx::from(func_type_index); + let func_type = builder.func_type_at(signature); + let table = TableIdx::from(table_index); + builder.stack_height.pop1(); + let drop_keep = builder.drop_keep_return_call(&func_type)?; + builder.bump_fuel_consumption(builder.fuel_costs().call); + builder.bump_fuel_consumption(drop_keep.fuel_consumption(builder.fuel_costs())); + builder + .alloc + .inst_builder + .push_inst(Instruction::ReturnCallIndirect { + drop_keep, + table, + func_type: signature, + }); + builder.reachable = false; + Ok(()) + }) + } + fn visit_call(&mut self, func_idx: u32) -> Result<(), TranslationError> { self.translate_if_reachable(|builder| { builder.bump_fuel_consumption(builder.fuel_costs().call); @@ -1110,18 +1193,14 @@ impl<'a> VisitOperator<'a> for FuncTranslator<'a> { ) -> Result<(), TranslationError> { self.translate_if_reachable(|builder| { builder.bump_fuel_consumption(builder.fuel_costs().call); - let func_type_index = SignatureIdx::from(func_type_index); + let func_type = SignatureIdx::from(func_type_index); let table = TableIdx::from(table_index); builder.stack_height.pop1(); - let func_type = builder.func_type_at(func_type_index); - builder.adjust_value_stack_for_call(&func_type); + builder.adjust_value_stack_for_call(&builder.func_type_at(func_type)); builder .alloc .inst_builder - .push_inst(Instruction::CallIndirect { - table, - func_type: func_type_index, - }); + .push_inst(Instruction::CallIndirect { table, func_type }); Ok(()) }) } diff --git a/crates/wasmi/src/engine/mod.rs b/crates/wasmi/src/engine/mod.rs index 32f56476b2..8d1e598bba 100644 --- a/crates/wasmi/src/engine/mod.rs +++ b/crates/wasmi/src/engine/mod.rs @@ -677,14 +677,35 @@ impl<'engine> EngineExecutor<'engine> { loop { match self.execute_wasm(ctx.as_context_mut(), &mut cache)? { WasmOutcome::Return => return Ok(()), - WasmOutcome::Call(ref func) => { + WasmOutcome::Call { + ref host_func, + instance, + } => { + let func = host_func; let host_func = match ctx.as_context().store.inner.resolve_func(func) { FuncEntity::Wasm(_) => unreachable!("`func` must be a host function"), FuncEntity::Host(host_func) => *host_func, }; - self.stack - .call_host_from_wasm(ctx.as_context_mut(), host_func, &self.res.func_types) - .map_err(|trap| TaggedTrap::host(*func, trap))?; + let result = self.stack.call_host_impl( + ctx.as_context_mut(), + host_func, + Some(&instance), + &self.res.func_types, + ); + if self.stack.frames.peek().is_some() { + // Case: There is a frame on the call stack. + // + // This is the default case and we can easily make host function + // errors return a resumable call handle. + result.map_err(|trap| TaggedTrap::host(*func, trap))?; + } else { + // Case: No frame is on the call stack. (edge case) + // + // This can happen if the host function was called by a tail call. + // In this case we treat host function errors the same as if we called + // the host function as root and do not allow to resume the call. + result.map_err(TaggedTrap::Wasm)?; + } } } } diff --git a/crates/wasmi/src/engine/stack/mod.rs b/crates/wasmi/src/engine/stack/mod.rs index 7cffa982cf..81b6dfc991 100644 --- a/crates/wasmi/src/engine/stack/mod.rs +++ b/crates/wasmi/src/engine/stack/mod.rs @@ -5,9 +5,9 @@ pub use self::{ frames::{CallStack, FuncFrame}, values::{ValueStack, ValueStackPtr}, }; -use super::{code_map::CodeMap, func_types::FuncTypeRegistry, FuncParams}; use crate::{ core::UntypedValue, + engine::{code_map::CodeMap, func_types::FuncTypeRegistry, FuncParams}, func::{HostFuncEntity, WasmFuncEntity}, AsContext, Instance, @@ -164,23 +164,6 @@ impl Stack { self.call_host_impl(ctx, host_func, None, func_types) } - /// Executes the given host function called by a Wasm function. - #[inline(always)] - pub fn call_host_from_wasm( - &mut self, - ctx: StoreContextMut, - host_func: HostFuncEntity, - func_types: &FuncTypeRegistry, - ) -> Result<(), Trap> { - let caller = self - .frames - .peek() - .copied() - .expect("must have a frame on the call stack"); - let instance = caller.instance(); - self.call_host_impl(ctx, host_func, Some(instance), func_types) - } - /// Executes the given host function. /// /// # Errors @@ -188,7 +171,7 @@ impl Stack { /// - If the host function returns a host side error or trap. /// - If the value stack overflowed upon pushing parameters or results. #[inline(always)] - fn call_host_impl( + pub fn call_host_impl( &mut self, ctx: StoreContextMut, host_func: HostFuncEntity, diff --git a/crates/wasmi/tests/e2e/v1/resumable_call.rs b/crates/wasmi/tests/e2e/v1/resumable_call.rs index a264b9719d..d2d9d73824 100644 --- a/crates/wasmi/tests/e2e/v1/resumable_call.rs +++ b/crates/wasmi/tests/e2e/v1/resumable_call.rs @@ -3,6 +3,7 @@ use core::slice; use wasmi::{ + Config, Engine, Error, Extern, @@ -13,6 +14,7 @@ use wasmi::{ ResumableCall, ResumableInvocation, Store, + TypedFunc, TypedResumableCall, TypedResumableInvocation, Value, @@ -20,13 +22,15 @@ use wasmi::{ use wasmi_core::{Trap, TrapCode, ValueType}; fn test_setup() -> (Store<()>, Linker<()>) { - let engine = Engine::default(); + let mut config = Config::default(); + config.wasm_tail_call(true); + let engine = Engine::new(&config); let store = Store::new(&engine, ()); let linker = >::new(&engine); (store, linker) } -fn resumable_call_smoldot_common(wasm: &str) -> (Store<()>, TypedResumableInvocation) { +fn resumable_call_smoldot_common(wasm: &str) -> (Store<()>, TypedFunc<(), i32>) { let (mut store, mut linker) = test_setup(); // The important part about this test is that this // host function has more results than parameters. @@ -47,16 +51,29 @@ fn resumable_call_smoldot_common(wasm: &str) -> (Store<()>, TypedResumableInvoca .start(&mut store) .unwrap(); let wasm_fn = instance.get_typed_func::<(), i32>(&store, "test").unwrap(); - let invocation = match wasm_fn.call_resumable(&mut store, ()).unwrap() { - TypedResumableCall::Resumable(invocation) => invocation, - TypedResumableCall::Finished(_) => panic!("expected TypedResumableCall::Resumable"), - }; - (store, invocation) + (store, wasm_fn) +} + +pub trait UnwrapResumable { + type Results; + + fn unwrap_resumable(self) -> TypedResumableInvocation; +} + +impl UnwrapResumable for Result, Trap> { + type Results = Results; + + fn unwrap_resumable(self) -> TypedResumableInvocation { + match self.unwrap() { + TypedResumableCall::Resumable(invocation) => invocation, + TypedResumableCall::Finished(_) => panic!("expected TypedResumableCall::Resumable"), + } + } } #[test] fn resumable_call_smoldot_01() { - let (mut store, invocation) = resumable_call_smoldot_common( + let (mut store, wasm_fn) = resumable_call_smoldot_common( r#" (module (import "env" "host_fn" (func $host_fn (result i32))) @@ -66,6 +83,50 @@ fn resumable_call_smoldot_01() { ) "#, ); + let invocation = wasm_fn.call_resumable(&mut store, ()).unwrap_resumable(); + match invocation.resume(&mut store, &[Value::I32(42)]).unwrap() { + TypedResumableCall::Finished(result) => assert_eq!(result, 42), + TypedResumableCall::Resumable(_) => panic!("expected TypeResumableCall::Finished"), + } +} + +#[test] +fn resumable_call_smoldot_tail_01() { + let (mut store, wasm_fn) = resumable_call_smoldot_common( + r#" + (module + (import "env" "host_fn" (func $host_fn (result i32))) + (func (export "test") (result i32) + (return_call $host_fn) + ) + ) + "#, + ); + assert_eq!( + wasm_fn + .call_resumable(&mut store, ()) + .unwrap_err() + .i32_exit_status(), + Some(100), + ); +} + +#[test] +fn resumable_call_smoldot_tail_02() { + let (mut store, wasm_fn) = resumable_call_smoldot_common( + r#" + (module + (import "env" "host_fn" (func $host (result i32))) + (func $wasm (result i32) + (return_call $host) + ) + (func (export "test") (result i32) + (call $wasm) + ) + ) + "#, + ); + let invocation = wasm_fn.call_resumable(&mut store, ()).unwrap_resumable(); match invocation.resume(&mut store, &[Value::I32(42)]).unwrap() { TypedResumableCall::Finished(result) => assert_eq!(result, 42), TypedResumableCall::Resumable(_) => panic!("expected TypeResumableCall::Finished"), @@ -74,7 +135,7 @@ fn resumable_call_smoldot_01() { #[test] fn resumable_call_smoldot_02() { - let (mut store, invocation) = resumable_call_smoldot_common( + let (mut store, wasm_fn) = resumable_call_smoldot_common( r#" (module (import "env" "host_fn" (func $host_fn (result i32))) @@ -91,6 +152,7 @@ fn resumable_call_smoldot_02() { ) "#, ); + let invocation = wasm_fn.call_resumable(&mut store, ()).unwrap_resumable(); match invocation.resume(&mut store, &[Value::I32(42)]).unwrap() { TypedResumableCall::Finished(result) => assert_eq!(result, 11), TypedResumableCall::Resumable(_) => panic!("expected TypeResumableCall::Finished"), diff --git a/crates/wasmi/tests/spec/mod.rs b/crates/wasmi/tests/spec/mod.rs index 6324967851..5475b4b4d1 100644 --- a/crates/wasmi/tests/spec/mod.rs +++ b/crates/wasmi/tests/spec/mod.rs @@ -77,6 +77,7 @@ fn make_config() -> Config { config.wasm_multi_value(true); config.wasm_bulk_memory(true); config.wasm_reference_types(true); + config.wasm_tail_call(true); config } @@ -95,6 +96,9 @@ define_spec_tests! { fn wasm_bulk("bulk"); fn wasm_call("call"); fn wasm_call_indirect("call_indirect"); + fn wasm_return_call("proposals/tail-call/return_call"); + + fn wasm_return_call_indirect("proposals/tail-call/return_call_indirect"); fn wasm_comments("comments"); fn wasm_const("const"); fn wasm_conversions("conversions");