diff --git a/crates/wit/src/parse_macro_input.rs b/crates/wit/src/parse_macro_input.rs index 9e866e4..87389d5 100644 --- a/crates/wit/src/parse_macro_input.rs +++ b/crates/wit/src/parse_macro_input.rs @@ -17,6 +17,7 @@ mod item_fn; mod item_foreign_mod; mod item_record; +mod utils; use crate::fce_ast_types::FCEAst; diff --git a/crates/wit/src/parse_macro_input/item_fn.rs b/crates/wit/src/parse_macro_input/item_fn.rs index a768e8d..4bcb848 100644 --- a/crates/wit/src/parse_macro_input/item_fn.rs +++ b/crates/wit/src/parse_macro_input/item_fn.rs @@ -20,7 +20,6 @@ use crate::ParsedType; use crate::fce_ast_types::FCEAst; use crate::fce_ast_types::AstFunctionItem; use crate::fce_ast_types::AstFuncArgument; -use crate::parsed_type::passing_style_of; use crate::syn_error; use syn::Result; @@ -148,23 +147,10 @@ fn check_parsed_functions<'a>( /// Vec>> => true /// &Vec => false fn contains_inner_ref(ty: &ParsedType) -> bool { - fn contains_inner_ref_impl(ty: &ParsedType) -> bool { - use crate::parsed_type::PassingStyle; - - match ty { - ParsedType::Vector(ty, passing_style) => match passing_style { - PassingStyle::ByValue => contains_inner_ref_impl(ty), - PassingStyle::ByRef | PassingStyle::ByMutRef => true, - }, - _ => match passing_style_of(ty) { - PassingStyle::ByValue => false, - PassingStyle::ByRef | PassingStyle::ByMutRef => true, - }, - } - } + use super::utils::contain_inner_ref; match ty { - ParsedType::Vector(ty, _) => contains_inner_ref_impl(ty), + ParsedType::Vector(ty, _) => contain_inner_ref(ty), // Structs are checked while parsing _ => false, } diff --git a/crates/wit/src/parse_macro_input/item_foreign_mod.rs b/crates/wit/src/parse_macro_input/item_foreign_mod.rs index a1e8190..68d6c4d 100644 --- a/crates/wit/src/parse_macro_input/item_foreign_mod.rs +++ b/crates/wit/src/parse_macro_input/item_foreign_mod.rs @@ -28,65 +28,105 @@ const WASM_IMPORT_MODULE_DIRECTIVE_NAME: &str = "wasm_import_module"; impl ParseMacroInput for syn::ItemForeignMod { fn parse_macro_input(self) -> Result { - match &self.abi.name { - Some(name) if name.value() != "C" => { - return syn_error!(self.span(), "only 'C' abi is allowed") - } - _ => {} + check_foreign_section(&self)?; + + let wasm_import_module: Option = parse_wasm_import_module(&self); + let namespace = try_extract_namespace(wasm_import_module, &self)?; + + let imports = extract_import_functions(&self)?; + check_imports(imports.iter().zip(self.items.iter().map(|i| i.span())))?; + + let extern_mod_item = fce_ast_types::AstExternModItem { + namespace, + imports, + original: Some(self), }; + Ok(FCEAst::ExternMod(extern_mod_item)) + } +} - let self_span = self.span(); +fn check_foreign_section(foreign_mod: &syn::ItemForeignMod) -> Result<()> { + match &foreign_mod.abi.name { + Some(name) if name.value() != "C" => { + syn_error!(foreign_mod.span(), "only 'C' abi is allowed") + } + _ => Ok(()), + } +} - let imports = self - .items - .iter() - .cloned() - .map(parse_raw_foreign_item) - .collect::>()?; +/// Tries to find and parse wasm module name from +/// #[link(wasm_import_module = "host")] +fn parse_wasm_import_module(foreign_mod: &syn::ItemForeignMod) -> Option { + foreign_mod + .attrs + .iter() + .filter_map(|attr| attr.parse_meta().ok()) + .filter(|meta| meta.path().is_ident(LINK_DIRECTIVE_NAME)) + .filter_map(|meta| { + let pair = match meta { + syn::Meta::List(mut meta_list) if meta_list.nested.len() == 1 => { + meta_list.nested.pop().unwrap() + } + _ => return None, + }; + Some(pair.into_tuple().0) + }) + .filter_map(|nested| match nested { + syn::NestedMeta::Meta(meta) => Some(meta), + _ => None, + }) + .filter(|meta| meta.path().is_ident(WASM_IMPORT_MODULE_DIRECTIVE_NAME)) + .map(extract_value) + .collect() +} - // try to find and parse wasm module name from - // #[link(wasm_import_module = "host")] - let wasm_import_module: Option = self - .attrs - .iter() - .filter_map(|attr| attr.parse_meta().ok()) - .filter(|meta| meta.path().is_ident(LINK_DIRECTIVE_NAME)) - .filter_map(|meta| { - let pair = match meta { - syn::Meta::List(mut meta_list) if meta_list.nested.len() == 1 => { - meta_list.nested.pop().unwrap() - } - _ => return None, - }; - Some(pair.into_tuple().0) - }) - .filter_map(|nested| match nested { - syn::NestedMeta::Meta(meta) => Some(meta), - _ => None, - }) - .filter(|meta| meta.path().is_ident(WASM_IMPORT_MODULE_DIRECTIVE_NAME)) - .map(extract_value) - .collect(); +fn try_extract_namespace( + attr: Option, + foreign_mod: &syn::ItemForeignMod, +) -> Result { + match attr { + Some(namespace) if namespace.is_empty() => syn_error!( + foreign_mod.span(), + "import module name should be defined by 'wasm_import_module' directive" + ), + Some(namespace) => Ok(namespace), + None => syn_error!( + foreign_mod.span(), + "import module name should be defined by 'wasm_import_module' directive" + ), + } +} - match wasm_import_module { - Some(namespace) if namespace.is_empty() => syn_error!( - self_span, - "import module name should be defined by 'wasm_import_module' directive" - ), - Some(namespace) => { - let extern_mod_item = fce_ast_types::AstExternModItem { - namespace, - imports, - original: Some(self), - }; - Ok(FCEAst::ExternMod(extern_mod_item)) +fn extract_import_functions( + foreign_mod: &syn::ItemForeignMod, +) -> Result> { + foreign_mod + .items + .iter() + .cloned() + .map(parse_raw_foreign_item) + .collect::>() +} + +/// This function checks whether these imports contains inner references. In this case glue +/// code couldn't be generated. +fn check_imports<'i>( + extern_fns: impl ExactSizeIterator, +) -> Result<()> { + use super::utils::contain_inner_ref; + + for (extern_fn, span) in extern_fns { + if let Some(output_type) = &extern_fn.signature.output_type { + if contain_inner_ref(output_type) { + return crate::syn_error!( + span, + "import function can't return a value with references" + ); } - None => syn_error!( - self_span, - "import module name should be defined by 'wasm_import_module' directive" - ), } } + + Ok(()) } fn parse_raw_foreign_item(raw_item: syn::ForeignItem) -> Result { diff --git a/crates/wit/src/parse_macro_input/utils.rs b/crates/wit/src/parse_macro_input/utils.rs new file mode 100644 index 0000000..f575d7e --- /dev/null +++ b/crates/wit/src/parse_macro_input/utils.rs @@ -0,0 +1,33 @@ +/* + * Copyright 2020 Fluence Labs Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::ParsedType; +use crate::parsed_type::PassingStyle; +use crate::parsed_type::passing_style_of; + +/// Checks whether a type contains a reference in one of types. +pub(super) fn contain_inner_ref(ty: &ParsedType) -> bool { + let passing_style = passing_style_of(ty); + match passing_style { + PassingStyle::ByValue => {} + PassingStyle::ByRef | PassingStyle::ByMutRef => return true, + }; + + match ty { + ParsedType::Vector(ty, _) => contain_inner_ref(ty), + _ => false, + } +} diff --git a/crates/wit/src/parsed_type.rs b/crates/wit/src/parsed_type.rs index bc6305c..e819318 100644 --- a/crates/wit/src/parsed_type.rs +++ b/crates/wit/src/parsed_type.rs @@ -57,7 +57,7 @@ pub enum ParsedType { Record(String, PassingStyle), // short type name } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] pub enum PassingStyle { ByValue, ByRef, diff --git a/crates/wit/src/parsed_type/fn_epilog.rs b/crates/wit/src/parsed_type/fn_epilog.rs index 3a16016..7666616 100644 --- a/crates/wit/src/parsed_type/fn_epilog.rs +++ b/crates/wit/src/parsed_type/fn_epilog.rs @@ -116,11 +116,14 @@ fn generate_epilog(ty: &Option) -> proc_macro2::TokenStream { fluence::internal::set_result_size(result.len() as _); } } - Some(ParsedType::Vector(ty, _)) => { + Some(ParsedType::Vector(ty, passing_style)) => { let generated_serializer_name = "__fce_generated_vec_serializer"; let generated_serializer_ident = new_ident!(generated_serializer_name); - let vector_serializer = - super::vector_utils::generate_vector_serializer(ty, generated_serializer_name); + let vector_serializer = super::vector_utils::generate_vector_serializer( + ty, + *passing_style, + generated_serializer_name, + ); quote! { #vector_serializer diff --git a/crates/wit/src/parsed_type/foreign_mod_prolog.rs b/crates/wit/src/parsed_type/foreign_mod_prolog.rs index ce24ebe..7f7ec92 100644 --- a/crates/wit/src/parsed_type/foreign_mod_prolog.rs +++ b/crates/wit/src/parsed_type/foreign_mod_prolog.rs @@ -81,11 +81,11 @@ impl ForeignModPrologGlueCodeGenerator for Vec { arg_transforms.extend(quote::quote! { let mut #arg_ident = std::mem::ManuallyDrop::new(#arg_ident); }); arg_drops.extend(quote::quote! { std::mem::ManuallyDrop::drop(&mut #arg_ident); }); }, - ParsedType::Vector(ty, _) => { + ParsedType::Vector(ty, passing_style) => { let generated_serializer_name = format!("__fce_generated_vec_serializer_{}", arg_name).replace("&<>", "_"); let generated_serializer_ident = new_ident!(generated_serializer_name); - let vector_serializer = super::vector_utils::generate_vector_serializer(ty, &generated_serializer_name); + let vector_serializer = super::vector_utils::generate_vector_serializer(ty, *passing_style, &generated_serializer_name); let arg_transform = quote::quote! { #vector_serializer diff --git a/crates/wit/src/parsed_type/vector_utils.rs b/crates/wit/src/parsed_type/vector_utils.rs index cec1d20..9370dc2 100644 --- a/crates/wit/src/parsed_type/vector_utils.rs +++ b/crates/wit/src/parsed_type/vector_utils.rs @@ -15,21 +15,13 @@ */ use super::ParsedType; +use super::PassingStyle; + use quote::quote; pub(crate) fn generate_vector_serializer( value_ty: &ParsedType, - arg_name: &str, -) -> proc_macro2::TokenStream { - let serializer_func = generate_vector_serializer_impl(value_ty, arg_name); - - quote! { - #serializer_func - } -} - -fn generate_vector_serializer_impl( - value_ty: &ParsedType, + vec_passing_style: PassingStyle, arg_name: &str, ) -> proc_macro2::TokenStream { let values_serializer = match value_ty { @@ -90,12 +82,13 @@ fn generate_vector_serializer_impl( (result.as_ptr() as _, (4 * result.len()) as _) } } - ParsedType::Vector(ty, _) => { + ParsedType::Vector(ty, passing_style) => { let serializer_name = format!("{}_{}", arg_name, ty) .replace("<", "_") .replace(">", "_") .replace("&", "_"); - let inner_vector_serializer = generate_vector_serializer(&*ty, &serializer_name); + let inner_vector_serializer = + generate_vector_serializer(&*ty, *passing_style, &serializer_name); let serializer_ident = crate::new_ident!(serializer_name); quote! { @@ -128,7 +121,7 @@ fn generate_vector_serializer_impl( let arg = crate::new_ident!(arg_name); quote! { - unsafe fn #arg(arg: Vec<#value_ty>) -> (u32, u32) { + unsafe fn #arg(arg: #vec_passing_style Vec<#value_ty>) -> (u32, u32) { #values_serializer } } diff --git a/crates/wit/src/token_stream_generator/record_generator/record_serializer.rs b/crates/wit/src/token_stream_generator/record_generator/record_serializer.rs index 8556e94..ff3d37b 100644 --- a/crates/wit/src/token_stream_generator/record_generator/record_serializer.rs +++ b/crates/wit/src/token_stream_generator/record_generator/record_serializer.rs @@ -44,7 +44,7 @@ impl RecordSerializerGlueCodeGenerator for fce_ast_types::AstRecordItem { std::mem::forget(#field_ident); } } - ParsedType::Vector(ty, _) => { + ParsedType::Vector(ty, passing_style) => { let generated_serializer_name = format!( "__fce_generated_vec_serializer_{}_{}", field.name.as_ref().unwrap(), @@ -54,6 +54,7 @@ impl RecordSerializerGlueCodeGenerator for fce_ast_types::AstRecordItem { let generated_serializer_ident = new_ident!(generated_serializer_name); let vector_serializer = crate::parsed_type::generate_vector_serializer( ty, + *passing_style, &generated_serializer_name, ); let serialized_field_ident = new_ident!(format!("serialized_arg_{}", id)); diff --git a/fluence/tests/import_functions/arrays_out_inner_refs.rs b/fluence/tests/import_functions/arrays_out_inner_refs.rs new file mode 100644 index 0000000..ba6079c --- /dev/null +++ b/fluence/tests/import_functions/arrays_out_inner_refs.rs @@ -0,0 +1,27 @@ +#![allow(improper_ctypes)] + +use fluence::fce; + +pub fn main() {} + +#[fce] +#[link(wasm_import_module = "arrays_passing_effector")] +extern "C" { + #[fce] + pub fn func_1() -> &String; + + #[fce] + pub fn func_2() -> &Vec>>>; + + #[fce] + pub fn func_3() -> Vec<&Vec>>>; + + #[fce] + pub fn func_4() -> Vec>>>; + + #[fce] + pub fn func_5() -> Vec>>>; + + #[fce] + pub fn func_6() -> Vec>>>; +} diff --git a/fluence/tests/import_functions/arrays_out_inner_refs.stderr b/fluence/tests/import_functions/arrays_out_inner_refs.stderr new file mode 100644 index 0000000..07a124c --- /dev/null +++ b/fluence/tests/import_functions/arrays_out_inner_refs.stderr @@ -0,0 +1,6 @@ +error: import function can't return a value with references + --> $DIR/arrays_out_inner_refs.rs:10:5 + | +10 | / #[fce] +11 | | pub fn func_1() -> &String; + | |_______________________________^ diff --git a/fluence/tests/test_runner.rs b/fluence/tests/test_runner.rs index 64e54c0..279bc52 100644 --- a/fluence/tests/test_runner.rs +++ b/fluence/tests/test_runner.rs @@ -8,6 +8,7 @@ fn test() { tests.pass("tests/export_functions/ref_basic_types.rs"); tests.compile_fail("tests/export_functions/improper_types.rs"); + tests.compile_fail("tests/import_functions/arrays_out_inner_refs.rs"); tests.pass("tests/import_functions/arrays.rs"); tests.pass("tests/import_functions/ref_arrays.rs"); tests.pass("tests/import_functions/basic_types.rs");