diff --git a/feature_tests/c/include/CyclicStructB.h b/feature_tests/c/include/CyclicStructB.h index 88b0942b9..e7d8235d9 100644 --- a/feature_tests/c/include/CyclicStructB.h +++ b/feature_tests/c/include/CyclicStructB.h @@ -18,6 +18,9 @@ CyclicStructA CyclicStructB_get_a(void); +typedef struct CyclicStructB_get_a_option_result {union {CyclicStructA ok; }; bool is_ok;} CyclicStructB_get_a_option_result; +CyclicStructB_get_a_option_result CyclicStructB_get_a_option(void); + diff --git a/feature_tests/cpp/include/CyclicStructB.d.hpp b/feature_tests/cpp/include/CyclicStructB.d.hpp index f93dd8d94..a97ca090b 100644 --- a/feature_tests/cpp/include/CyclicStructB.d.hpp +++ b/feature_tests/cpp/include/CyclicStructB.d.hpp @@ -28,6 +28,8 @@ struct CyclicStructB { inline static CyclicStructA get_a(); + inline static std::optional get_a_option(); + inline diplomat::capi::CyclicStructB AsFFI() const; inline static CyclicStructB FromFFI(diplomat::capi::CyclicStructB c_struct); }; diff --git a/feature_tests/cpp/include/CyclicStructB.hpp b/feature_tests/cpp/include/CyclicStructB.hpp index dfc38f135..d9348b155 100644 --- a/feature_tests/cpp/include/CyclicStructB.hpp +++ b/feature_tests/cpp/include/CyclicStructB.hpp @@ -19,6 +19,9 @@ namespace capi { diplomat::capi::CyclicStructA CyclicStructB_get_a(void); + typedef struct CyclicStructB_get_a_option_result {union {diplomat::capi::CyclicStructA ok; }; bool is_ok;} CyclicStructB_get_a_option_result; + CyclicStructB_get_a_option_result CyclicStructB_get_a_option(void); + } // extern "C" } // namespace capi @@ -29,6 +32,11 @@ inline CyclicStructA CyclicStructB::get_a() { return CyclicStructA::FromFFI(result); } +inline std::optional CyclicStructB::get_a_option() { + auto result = diplomat::capi::CyclicStructB_get_a_option(); + return result.is_ok ? std::optional(CyclicStructA::FromFFI(result.ok)) : std::nullopt; +} + inline diplomat::capi::CyclicStructB CyclicStructB::AsFFI() const { return diplomat::capi::CyclicStructB { diff --git a/feature_tests/dart/lib/src/CyclicStructB.g.dart b/feature_tests/dart/lib/src/CyclicStructB.g.dart index 6b8aba647..62ada634e 100644 --- a/feature_tests/dart/lib/src/CyclicStructB.g.dart +++ b/feature_tests/dart/lib/src/CyclicStructB.g.dart @@ -33,6 +33,14 @@ final class CyclicStructB { return CyclicStructA._fromFfi(result); } + static CyclicStructA? getAOption() { + final result = _CyclicStructB_get_a_option(); + if (!result.isOk) { + return null; + } + return CyclicStructA._fromFfi(result.union.ok); + } + @override bool operator ==(Object other) => other is CyclicStructB && @@ -48,3 +56,8 @@ final class CyclicStructB { @ffi.Native<_CyclicStructAFfi Function()>(isLeaf: true, symbol: 'CyclicStructB_get_a') // ignore: non_constant_identifier_names external _CyclicStructAFfi _CyclicStructB_get_a(); + +@meta.RecordUse() +@ffi.Native<_ResultCyclicStructAFfiVoid Function()>(isLeaf: true, symbol: 'CyclicStructB_get_a_option') +// ignore: non_constant_identifier_names +external _ResultCyclicStructAFfiVoid _CyclicStructB_get_a_option(); diff --git a/feature_tests/dart/lib/src/lib.g.dart b/feature_tests/dart/lib/src/lib.g.dart index c00c757cb..fdc4f944a 100644 --- a/feature_tests/dart/lib/src/lib.g.dart +++ b/feature_tests/dart/lib/src/lib.g.dart @@ -118,6 +118,31 @@ class _FinalizedArena { } } +final class _ResultCyclicStructAFfiVoidUnion extends ffi.Union { + external _CyclicStructAFfi ok; + +} + +final class _ResultCyclicStructAFfiVoid extends ffi.Struct { + external _ResultCyclicStructAFfiVoidUnion union; + + @ffi.Bool() + external bool isOk; + + + factory _ResultCyclicStructAFfiVoid.ok(_CyclicStructAFfi val) { + final struct = ffi.Struct.create<_ResultCyclicStructAFfiVoid>(); + struct.isOk = true; + struct.union.ok = val; + return struct; + } + factory _ResultCyclicStructAFfiVoid.err() { + final struct = ffi.Struct.create<_ResultCyclicStructAFfiVoid>(); + struct.isOk = false; + return struct; + } +} + final class _ResultDoubleVoidUnion extends ffi.Union { @ffi.Double() external double ok; diff --git a/feature_tests/js/api/CyclicStructA.mjs b/feature_tests/js/api/CyclicStructA.mjs index 86c30671c..8c60d800f 100644 --- a/feature_tests/js/api/CyclicStructA.mjs +++ b/feature_tests/js/api/CyclicStructA.mjs @@ -49,29 +49,25 @@ export class CyclicStructA { // and passes it down to individual fields containing the borrow. // This method does not attempt to handle any dependencies between lifetimes, the caller // should handle this when constructing edge arrays. - static _fromFFI(internalConstructor, ptr) { + static _fromFFI(internalConstructor, primitiveValue) { if (internalConstructor !== diplomatRuntime.internalConstructor) { throw new Error("CyclicStructA._fromFFI is not meant to be called externally. Please use the default constructor."); } var structObj = {}; - const aDeref = ptr; + const aDeref = primitiveValue; structObj.a = CyclicStructB._fromFFI(diplomatRuntime.internalConstructor, aDeref); return new CyclicStructA(structObj, internalConstructor); } static getB() { - const diplomatReceive = new diplomatRuntime.DiplomatReceiveBuf(wasm, 1, 1, false); - - const result = wasm.CyclicStructA_get_b(diplomatReceive.buffer); + const result = wasm.CyclicStructA_get_b(); try { - return CyclicStructB._fromFFI(diplomatRuntime.internalConstructor, diplomatReceive.buffer); + return CyclicStructB._fromFFI(diplomatRuntime.internalConstructor, result); } - finally { - diplomatReceive.free(); - } + finally {} } cyclicOut() { diff --git a/feature_tests/js/api/CyclicStructB.d.ts b/feature_tests/js/api/CyclicStructB.d.ts index 658eb7636..955c36ea1 100644 --- a/feature_tests/js/api/CyclicStructB.d.ts +++ b/feature_tests/js/api/CyclicStructB.d.ts @@ -13,4 +13,6 @@ export class CyclicStructB { constructor(structObj : CyclicStructB_Obj); static getA(): CyclicStructA; + + static getAOption(): CyclicStructA | null; } \ No newline at end of file diff --git a/feature_tests/js/api/CyclicStructB.mjs b/feature_tests/js/api/CyclicStructB.mjs index 53f4f3a50..2129fa5fd 100644 --- a/feature_tests/js/api/CyclicStructB.mjs +++ b/feature_tests/js/api/CyclicStructB.mjs @@ -49,24 +49,37 @@ export class CyclicStructB { // and passes it down to individual fields containing the borrow. // This method does not attempt to handle any dependencies between lifetimes, the caller // should handle this when constructing edge arrays. - static _fromFFI(internalConstructor, ptr) { + static _fromFFI(internalConstructor, primitiveValue) { if (internalConstructor !== diplomatRuntime.internalConstructor) { throw new Error("CyclicStructB._fromFFI is not meant to be called externally. Please use the default constructor."); } var structObj = {}; - const fieldDeref = (new Uint8Array(wasm.memory.buffer, ptr, 1))[0]; - structObj.field = fieldDeref; + structObj.field = primitiveValue; + return new CyclicStructB(structObj, internalConstructor); } static getA() { - const diplomatReceive = new diplomatRuntime.DiplomatReceiveBuf(wasm, 1, 1, false); + const result = wasm.CyclicStructB_get_a(); + + try { + return CyclicStructA._fromFFI(diplomatRuntime.internalConstructor, result); + } + + finally {} + } + + static getAOption() { + const diplomatReceive = new diplomatRuntime.DiplomatReceiveBuf(wasm, 2, 1, true); - const result = wasm.CyclicStructB_get_a(diplomatReceive.buffer); + const result = wasm.CyclicStructB_get_a_option(diplomatReceive.buffer); try { - return CyclicStructA._fromFFI(diplomatRuntime.internalConstructor, diplomatReceive.buffer); + if (!diplomatReceive.resultFlag) { + return null; + } + return CyclicStructA._fromFFI(diplomatRuntime.internalConstructor, (new Uint8Array(wasm.memory.buffer, diplomatReceive.buffer, 1))[0]); } finally { diff --git a/feature_tests/js/api/CyclicStructC.mjs b/feature_tests/js/api/CyclicStructC.mjs index 561d66a21..5107418d7 100644 --- a/feature_tests/js/api/CyclicStructC.mjs +++ b/feature_tests/js/api/CyclicStructC.mjs @@ -49,12 +49,12 @@ export class CyclicStructC { // and passes it down to individual fields containing the borrow. // This method does not attempt to handle any dependencies between lifetimes, the caller // should handle this when constructing edge arrays. - static _fromFFI(internalConstructor, ptr) { + static _fromFFI(internalConstructor, primitiveValue) { if (internalConstructor !== diplomatRuntime.internalConstructor) { throw new Error("CyclicStructC._fromFFI is not meant to be called externally. Please use the default constructor."); } var structObj = {}; - const aDeref = ptr; + const aDeref = primitiveValue; structObj.a = CyclicStructA._fromFFI(diplomatRuntime.internalConstructor, aDeref); return new CyclicStructC(structObj, internalConstructor); diff --git a/feature_tests/js/test/struct-ts.mjs b/feature_tests/js/test/struct-ts.mjs index 5679066c1..ff9a237bc 100644 --- a/feature_tests/js/test/struct-ts.mjs +++ b/feature_tests/js/test/struct-ts.mjs @@ -1,5 +1,5 @@ import test from 'ava'; -import { MyEnum, MyStruct } from "diplomat-wasm-js-feature-tests"; +import { MyEnum, MyStruct, CyclicStructB } from "diplomat-wasm-js-feature-tests"; test("Verify invariants of struct", t => { const s = MyStruct.new_(); t.is(s.a, 17); @@ -23,3 +23,11 @@ test("Test struct creation", t => { }); t.is(s.intoA(), 17); }); +test("Function Returning Nested Struct of One Field", t => { + const a = CyclicStructB.getA(); + t.is(a.cyclicOut(), "0"); +}); +test("Function De-Referencing Nested Struct of One Primitive", t => { + const a = CyclicStructB.getAOption(); + t.is(a.cyclicOut(), "0"); +}); diff --git a/feature_tests/js/test/struct-ts.mts b/feature_tests/js/test/struct-ts.mts index e6e45c93b..2a0efc56c 100644 --- a/feature_tests/js/test/struct-ts.mts +++ b/feature_tests/js/test/struct-ts.mts @@ -1,5 +1,5 @@ import test from 'ava'; -import { MyEnum, MyStruct } from "diplomat-wasm-js-feature-tests"; +import { MyEnum, MyStruct, CyclicStructB } from "diplomat-wasm-js-feature-tests"; test("Verify invariants of struct", t => { const s = MyStruct.new_(); @@ -24,4 +24,14 @@ test("Test struct creation", t => { g: MyEnum.B }); t.is(s.intoA(), 17); +}); + +test("Function Returning Nested Struct of One Field", t => { + const a = CyclicStructB.getA(); + t.is(a.cyclicOut(), "0"); +}); + +test("Function De-Referencing Nested Struct of One Primitive", t => { + const a = CyclicStructB.getAOption(); + t.is(a.cyclicOut(), "0"); }); \ No newline at end of file diff --git a/feature_tests/js/test/struct.mjs b/feature_tests/js/test/struct.mjs index c8601bc47..31041f43f 100644 --- a/feature_tests/js/test/struct.mjs +++ b/feature_tests/js/test/struct.mjs @@ -1,5 +1,5 @@ import test from 'ava'; -import { MyEnum, MyStruct, ScalarPairWithPadding, BigStructWithStuff } from "diplomat-wasm-js-feature-tests"; +import { MyEnum, MyStruct, ScalarPairWithPadding, BigStructWithStuff, CyclicStructB } from "diplomat-wasm-js-feature-tests"; test("Verify invariants of struct", t => { const s = MyStruct.new_("hello"); @@ -46,3 +46,13 @@ test("Test struct layout: complex struct with multiple padding types and contain s.assertValue(853); t.is(true, true); // Ava doesn't like tests without assertions }); + +test("Function Returning Nested Struct of One Primitive", t => { + const a = CyclicStructB.getA(); + t.is(a.cyclicOut(), "0"); +}); + +test("Function De-Referencing Nested Struct of One Primitive", t => { + const a = CyclicStructB.getAOption(); + t.is(a.cyclicOut(), "0"); +}); \ No newline at end of file diff --git a/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/CyclicStructB.kt b/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/CyclicStructB.kt index 87bb2a3ff..e3a8b6af0 100644 --- a/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/CyclicStructB.kt +++ b/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/CyclicStructB.kt @@ -8,6 +8,7 @@ import com.sun.jna.Structure internal interface CyclicStructBLib: Library { fun CyclicStructB_get_a(): CyclicStructANative + fun CyclicStructB_get_a_option(): OptionCyclicStructANative } internal class CyclicStructBNative: Structure(), Structure.ByValue { @@ -36,6 +37,17 @@ class CyclicStructB internal constructor ( val returnStruct = CyclicStructA(returnVal) return returnStruct } + + fun getAOption(): CyclicStructA? { + + val returnVal = lib.CyclicStructB_get_a_option(); + + val intermediateOption = returnVal.option() ?: return null + + val returnStruct = CyclicStructA(intermediateOption) + return returnStruct + + } } } diff --git a/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/Lib.kt b/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/Lib.kt index 04522f675..8f5c948a5 100644 --- a/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/Lib.kt +++ b/feature_tests/kotlin/somelib/src/main/kotlin/dev/diplomattest/somelib/Lib.kt @@ -495,6 +495,26 @@ internal class OptionByte: Structure(), Structure.ByValue { } } } +internal class OptionCyclicStructANative: Structure(), Structure.ByValue { + @JvmField + internal var value: CyclicStructANative = CyclicStructANative() + + @JvmField + internal var isOk: Byte = 0 + + // Define the fields of the struct + override fun getFieldOrder(): List { + return listOf("value", "isOk") + } + + internal fun option(): CyclicStructANative? { + if (isOk == 1.toByte()) { + return value + } else { + return null + } + } +} internal class OptionDouble: Structure(), Structure.ByValue { @JvmField internal var value: Double = 0.0 diff --git a/feature_tests/src/structs.rs b/feature_tests/src/structs.rs index 8cb1515a8..66039ca39 100644 --- a/feature_tests/src/structs.rs +++ b/feature_tests/src/structs.rs @@ -228,6 +228,10 @@ pub mod ffi { pub fn get_a() -> CyclicStructA { Default::default() } + + pub fn get_a_option() -> Option { + Some(Default::default()) + } } impl CyclicStructC { diff --git a/tool/src/js/converter.rs b/tool/src/js/converter.rs index 616d59262..39e0adae4 100644 --- a/tool/src/js/converter.rs +++ b/tool/src/js/converter.rs @@ -284,25 +284,48 @@ impl<'tcx> TyGenContext<'_, 'tcx> { offset: usize, ) -> Cow<'tcx, str> { let pointer = if offset == 0 { - variable_name + variable_name.clone() } else { format!("{variable_name} + {offset}").into() }; - match *ty { + match ty { Type::Enum(..) => format!("diplomatRuntime.enumDiscriminant(wasm, {pointer})").into(), Type::Opaque(..) => format!("diplomatRuntime.ptrRead(wasm, {pointer})").into(), - // Structs always assume they're being passed a pointer, so they handle this in their constructors: - // See NestedBorrowedFields - Type::Struct(..) | Type::Slice(..) | Type::DiplomatOption(..) => pointer, Type::Primitive(p) => format!( "(new {ctor}(wasm.memory.buffer, {pointer}, 1))[0]{cmp}", - ctor = self.formatter.fmt_primitive_slice(p), + ctor = self.formatter.fmt_primitive_slice(*p), cmp = match p { PrimitiveType::Bool => " === 1", _ => "", } ) .into(), + Type::Struct(st) + if match st.id() { + hir::TypeId::OutStruct(s) => { + self.only_primitive(self.tcx.resolve_out_struct(s)) + } + hir::TypeId::Struct(s) => self.only_primitive(self.tcx.resolve_struct(s)), + _ => false, + } => + { + match st.id() { + hir::TypeId::OutStruct(s) => { + let first = self.tcx.resolve_out_struct(s).fields.first().unwrap(); + + self.gen_c_to_js_deref_for_type(&first.ty, variable_name, offset) + } + hir::TypeId::Struct(s) => { + let first = self.tcx.resolve_struct(s).fields.first().unwrap(); + + self.gen_c_to_js_deref_for_type(&first.ty, variable_name, offset) + } + _ => unreachable!("Expected struct, got {:?}", st.id()), + } + } + // Structs (nearly) always assume they're being passed a pointer, so they handle this in their constructors: + // See NestedBorrowedFields + Type::Struct(..) | Type::Slice(..) | Type::DiplomatOption(..) => pointer, _ => unreachable!("Unknown AST/HIR variant {:?}", ty), } } @@ -381,6 +404,7 @@ impl<'tcx> TyGenContext<'_, 'tcx> { ReturnType::Infallible(SuccessType::OutType(ref o)) => { let mut result = "result"; match o { + Type::Struct(s) if self.wraps_a_primitive(s) => {} Type::Struct(_) | Type::Slice(_) => { let layout = crate::js::layout::type_size_alignment(o, self.tcx); let size = layout.size(); diff --git a/tool/src/js/gen.rs b/tool/src/js/gen.rs index 12413817d..52bf7fd05 100644 --- a/tool/src/js/gen.rs +++ b/tool/src/js/gen.rs @@ -10,8 +10,8 @@ use diplomat_core::hir::borrowing_param::{ BorrowedLifetimeInfo, LifetimeEdge, LifetimeEdgeKind, ParamBorrowInfo, StructBorrowInfo, }; use diplomat_core::hir::{ - self, EnumDef, LifetimeEnv, Method, OpaqueDef, SpecialMethod, SpecialMethodPresence, Type, - TypeContext, TypeId, + self, EnumDef, LifetimeEnv, Method, OpaqueDef, SpecialMethod, SpecialMethodPresence, + StructPathLike, Type, TypeContext, TypeId, }; use askama::{self, Template}; @@ -293,6 +293,34 @@ impl<'tcx> TyGenContext<'_, 'tcx> { (fields, needs_force_padding) } + pub(super) fn only_primitive(&self, st: &hir::StructDef

) -> bool { + if st.fields.len() != 1 { + return false; + } + + let first = st.fields.first().unwrap(); + + match &first.ty { + hir::Type::Primitive(..) => true, + hir::Type::Struct(s) => match s.id() { + hir::TypeId::Struct(s) => self.only_primitive(self.tcx.resolve_struct(s)), + hir::TypeId::OutStruct(s) => self.only_primitive(self.tcx.resolve_out_struct(s)), + _ => false, + }, + _ => false, + } + } + + /// WASM only returns a primitive (instead of a pointer) if our struct just wraps a primitive (or nests a struct that only has one primitive as a field). + /// This is a quick way to verify that we are grabbing a value instead of a pointer. + pub(super) fn wraps_a_primitive(&self, st: &hir::ReturnableStructPath) -> bool { + match st.resolve(self.tcx) { + hir::ReturnableStructDef::OutStruct(s) => self.only_primitive(s), + hir::ReturnableStructDef::Struct(s) => self.only_primitive(s), + _ => false, + } + } + /// Generate a struct type's body for a file from the given definition. /// /// Used for both [`hir::TypeDef::Struct`] and [`hir::TypeDef::OutStruct`], which is why `is_out` exists. @@ -321,6 +349,9 @@ impl<'tcx> TyGenContext<'_, 'tcx> { fields: &'a Vec>, methods: &'a MethodsInfo<'a>, + wraps_primitive: bool, + owns_wrapped_primitive: bool, + docs: String, } @@ -336,6 +367,13 @@ impl<'tcx> TyGenContext<'_, 'tcx> { fields, methods, + wraps_primitive: self.only_primitive(struct_def), + owns_wrapped_primitive: !struct_def.fields.is_empty() + && matches!( + struct_def.fields.first().unwrap().ty, + hir::Type::Primitive(..) + ), + docs: self.formatter.fmt_docs(&struct_def.docs), } .render() diff --git a/tool/templates/js/struct.js.jinja b/tool/templates/js/struct.js.jinja index 636225de5..b79a77af7 100644 --- a/tool/templates/js/struct.js.jinja +++ b/tool/templates/js/struct.js.jinja @@ -112,15 +112,20 @@ export class {{type_name}} { // This method does not attempt to handle any dependencies between lifetimes, the caller // should handle this when constructing edge arrays. {% endif -%} - static _fromFFI(internalConstructor, ptr{%- for lifetime in lifetimes.all_lifetimes() -%}, {{lifetimes.fmt_lifetime(lifetime)}}Edges{%- endfor -%}) { + static _fromFFI(internalConstructor, {%- if !wraps_primitive %} ptr {%- else %} primitiveValue {%- endif -%} {%- for lifetime in lifetimes.all_lifetimes() -%}, {{lifetimes.fmt_lifetime(lifetime)}}Edges{%- endfor -%}) { if (internalConstructor !== diplomatRuntime.internalConstructor) { throw new Error("{{type_name}}._fromFFI is not meant to be called externally. Please use the default constructor."); } var structObj = {}; + + {%- if wraps_primitive && owns_wrapped_primitive %} + structObj.{{fields.first().unwrap().field_name}} = primitiveValue; + {% else %} {%- for field in fields %} - const {{field.field_name}}Deref = {{field.c_to_js_deref}}; + const {{field.field_name}}Deref = {%- if wraps_primitive && !owns_wrapped_primitive %} primitiveValue {%- else %} {{field.c_to_js_deref}} {%- endif %}; structObj.{{field.field_name}} = {{field.c_to_js}}; {%- endfor %} + {%- endif %} return new {{type_name}}(structObj, internalConstructor); }