diff --git a/native/wasmex/Cargo.lock b/native/wasmex/Cargo.lock index 13c2cc6..375ab30 100644 --- a/native/wasmex/Cargo.lock +++ b/native/wasmex/Cargo.lock @@ -1437,6 +1437,7 @@ dependencies = [ "num-bigint", "regex-lite", "rustler_codegen", + "serde", ] [[package]] diff --git a/native/wasmex/Cargo.toml b/native/wasmex/Cargo.toml index 2f9fb63..ae0a3a9 100644 --- a/native/wasmex/Cargo.toml +++ b/native/wasmex/Cargo.toml @@ -15,7 +15,7 @@ path = "src/lib.rs" crate-type = ["dylib"] [dependencies] -rustler = { version = "0.35", features = ["big_integer"] } +rustler = { version = "0.35", features = ["big_integer", "serde"] } once_cell = "1.20.2" rand = "0.8.5" wasmtime = "26.0.1" diff --git a/native/wasmex/src/component_instance.rs b/native/wasmex/src/component_instance.rs index ce3d5a5..36e0429 100644 --- a/native/wasmex/src/component_instance.rs +++ b/native/wasmex/src/component_instance.rs @@ -117,7 +117,7 @@ fn term_to_val(param_term: &Term, param_type: &Type) -> Result { let decoded_map = param_term.decode::>()?; let terms = decoded_map .iter() - .map(|(key, val)| (key.decode::().unwrap(), val)) + .map(|(key_term, val)| (term_to_field_name(key_term), val)) .collect::>(); for field in record.fields() { let field_term_option = terms.iter().find(|(k, _)| k == field.name); @@ -148,6 +148,17 @@ fn term_to_val(param_term: &Term, param_type: &Type) -> Result { } } +fn term_to_field_name(key_term: &Term) -> String { + match key_term.get_type() { + TermType::Atom => key_term.atom_to_string().unwrap(), + _ => key_term.decode::().unwrap(), + } +} + +fn field_name_to_term<'a>(env: &rustler::Env<'a>, field_name: &str) -> Term<'a> { + rustler::serde::atoms::str_to_term(env, field_name).unwrap() +} + fn encode_result(env: rustler::Env, vals: Vec) -> Term { let result_term = match vals.len() { 1 => val_to_term(vals.first().unwrap(), env), @@ -182,8 +193,8 @@ fn val_to_term<'a>(val: &Val, env: rustler::Env<'a>) -> Term<'a> { Val::Record(record) => { let converted_pairs = record .iter() - .map(|(key, val)| (key, val_to_term(val, env))) - .collect::)>>(); + .map(|(key, val)| (field_name_to_term(&env, key), val_to_term(val, env))) + .collect::>(); Term::map_from_pairs(env, converted_pairs.as_slice()).unwrap() } Val::Tuple(tuple) => { diff --git a/test/components/component_types_test.exs b/test/components/component_types_test.exs index 79b3ab6..ee06a78 100644 --- a/test/components/component_types_test.exs +++ b/test/components/component_types_test.exs @@ -33,12 +33,16 @@ defmodule Wasm.Components.ComponentTypesTest do end test "records", %{instance: instance} do - # don't love this yet, be nicer to support atom keys - assert {:ok, %{"x" => 1, "y" => 2}} = + assert {:ok, %{x: 1, y: 2}} = Wasmex.Components.Instance.call_function(instance, "id-record", [ %{"x" => 1, "y" => 2} ]) + assert {:ok, %{x: 1, y: 2}} = + Wasmex.Components.Instance.call_function(instance, "id-record", [ + %{x: 1, y: 2} + ]) + assert {:error, _error} = Wasmex.Components.Instance.call_function(instance, "id-record", [ %{"invalid-field" => "foo"}