diff --git a/src/main.rs b/src/main.rs index 02a0eda96..619239091 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ use structopt::StructOpt; mod macros; pub mod apis; pub mod common; +mod recovery; pub mod sighandler; #[cfg(test)] mod spectests; @@ -95,9 +96,7 @@ fn execute_wasm(wasm_path: PathBuf) -> Result<(), String> { Some(&webassembly::Export::Function(index)) => index, _ => panic!("Main function not found"), }); - let main: extern "C" fn(&webassembly::Instance) = - get_instance_function!(instance, func_index); - main(&instance); + instance.start_func(func_index).unwrap(); } Ok(()) diff --git a/src/recovery.rs b/src/recovery.rs new file mode 100644 index 000000000..11c1dea82 --- /dev/null +++ b/src/recovery.rs @@ -0,0 +1,41 @@ +//! When a WebAssembly module triggers any traps, we perform recovery here. + +use std::cell::UnsafeCell; + +extern "C" { + fn setjmp(env: *mut ::nix::libc::c_void) -> ::nix::libc::c_int; + fn longjmp(env: *mut ::nix::libc::c_void, val: ::nix::libc::c_int) -> !; +} + +const SETJMP_BUFFER_LEN: usize = 27; + +thread_local! { + static SETJMP_BUFFER: UnsafeCell<[::nix::libc::c_int; SETJMP_BUFFER_LEN]> = UnsafeCell::new([0; SETJMP_BUFFER_LEN]); +} + +/// Calls a function with longjmp receiver installed. The function must be compiled from WebAssembly; +/// Otherwise, the behavior is undefined. +pub unsafe fn protected_call(f: fn(T) -> R, p: T) -> Result { + let jmp_buf = SETJMP_BUFFER.with(|buf| buf.get()); + let prev_jmp_buf = *jmp_buf; + + let signum = setjmp(jmp_buf as *mut ::nix::libc::c_void); + if signum != 0 { + *jmp_buf = prev_jmp_buf; + Err(signum) + } else { + let ret = f(p); + *jmp_buf = prev_jmp_buf; + Ok(ret) + } +} + +/// Unwinds to last protected_call. +pub unsafe fn do_unwind(signum: i32) -> ! { + let jmp_buf = SETJMP_BUFFER.with(|buf| buf.get()); + if *jmp_buf == [0; SETJMP_BUFFER_LEN] { + ::std::process::abort(); + } + + longjmp(jmp_buf as *mut ::nix::libc::c_void, signum) +} diff --git a/src/sighandler.rs b/src/sighandler.rs index d75fb179b..c02c4adb2 100644 --- a/src/sighandler.rs +++ b/src/sighandler.rs @@ -4,12 +4,11 @@ //! //! Please read more about this here: https://github.com/CraneStation/wasmtime/issues/15 //! This code is inspired by: https://github.com/pepyakin/wasmtime/commit/625a2b6c0815b21996e111da51b9664feb174622 +use super::recovery; use nix::sys::signal::{ - sigaction, SaFlags, SigAction, SigHandler, SigSet, Signal, SIGBUS, SIGFPE, SIGILL, SIGSEGV, + sigaction, SaFlags, SigAction, SigHandler, SigSet, SIGBUS, SIGFPE, SIGILL, SIGSEGV, }; -static mut SETJMP_BUFFER: [::nix::libc::c_int; 27] = [0; 27]; - pub unsafe fn install_sighandler() { let sa = SigAction::new( SigHandler::Handler(signal_trap_handler), @@ -20,29 +19,10 @@ pub unsafe fn install_sighandler() { sigaction(SIGILL, &sa).unwrap(); sigaction(SIGSEGV, &sa).unwrap(); sigaction(SIGBUS, &sa).unwrap(); - let signum = setjmp((&mut SETJMP_BUFFER[..]).as_mut_ptr() as *mut ::nix::libc::c_void); - if signum != 0 { - let signal = Signal::from_c_int(signum).unwrap(); - match signal { - SIGFPE => panic!("floating-point exception"), - SIGILL => panic!("illegal instruction"), - SIGSEGV => panic!("segmentation violation"), - SIGBUS => panic!("bus error"), - _ => panic!("signal error: {:?}", signal), - }; - } -} - -extern "C" { - fn setjmp(env: *mut ::nix::libc::c_void) -> ::nix::libc::c_int; - fn longjmp(env: *mut ::nix::libc::c_void, val: ::nix::libc::c_int); } extern "C" fn signal_trap_handler(signum: ::nix::libc::c_int) { unsafe { - longjmp( - (&mut SETJMP_BUFFER).as_mut_ptr() as *mut ::nix::libc::c_void, - signum, - ); + recovery::do_unwind(signum); } } diff --git a/src/webassembly/instance.rs b/src/webassembly/instance.rs index c1c5c7637..b98901f54 100644 --- a/src/webassembly/instance.rs +++ b/src/webassembly/instance.rs @@ -22,6 +22,7 @@ use std::slice; use std::sync::Arc; use super::super::common::slice::{BoundedSlice, UncheckedSlice}; +use super::super::recovery; use super::errors::ErrorKind; use super::import_object::{ImportObject, ImportValue}; use super::math_intrinsics; @@ -225,7 +226,8 @@ impl Instance { // let r = *Arc::from_raw(isa_ptr); compile_function(&*options.isa, function_body).unwrap() // unimplemented!() - }).collect(); + }) + .collect(); for compiled_func in compiled_funcs.into_iter() { let CompiledFunction { @@ -475,7 +477,8 @@ impl Instance { &mem[..], mem.current as usize * LinearMemory::WASM_PAGE_SIZE, ) - }).collect(); + }) + .collect(); let globals_pointer: GlobalsSlice = globals[..].into(); let data_pointers = DataPointers { @@ -513,10 +516,16 @@ impl Instance { get_function_addr(&func_index, &self.import_functions, &self.functions) } - pub fn start(&self) { + pub fn start_func(&self, func_index: FuncIndex) -> Result<(), i32> { + let func: fn(&Instance) = get_instance_function!(&self, func_index); + unsafe { recovery::protected_call(func, self) } + } + + pub fn start(&self) -> Result<(), i32> { if let Some(func_index) = self.start_func { - let func: fn(&Instance) = get_instance_function!(&self, func_index); - func(self) + self.start_func(func_index) + } else { + panic!("start func not found") } }