diff --git a/liveview_native/live_nx_iree/config/runtime.exs b/liveview_native/live_nx_iree/config/runtime.exs
index e6043a5..73a14f6 100644
--- a/liveview_native/live_nx_iree/config/runtime.exs
+++ b/liveview_native/live_nx_iree/config/runtime.exs
@@ -20,6 +20,8 @@ if System.get_env("PHX_SERVER") do
config :live_nx_iree, LiveNxIREEWeb.Endpoint, server: true
end
+config :nx, :default_backend, NxIREE.Tensor
+
if config_env() == :prod do
database_url =
System.get_env("DATABASE_URL") ||
diff --git a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex
index 018233f..670d244 100644
--- a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex
+++ b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex
@@ -15,7 +15,14 @@ defmodule LiveNxIREEWeb.HomeLive do
dbg(self())
- socket = assign(socket, bytecode: nil, function_signature: nil, device: nil)
+ socket =
+ assign(socket,
+ bytecode: nil,
+ function_signature: nil,
+ device: nil,
+ inputs: nil,
+ num_outputs: nil
+ )
{:ok, socket}
end
@@ -43,7 +50,7 @@ defmodule LiveNxIREEWeb.HomeLive do
# end
@impl true
- def handle_info({:nx, :execute, function, input_templates, target_device, reply_to_pid}, socket) do
+ def handle_info({:nx, :execute, function, inputs, target_device, reply_to_pid}, socket) do
fun =
case function do
{m, f, a} ->
@@ -71,19 +78,36 @@ defmodule LiveNxIREEWeb.HomeLive do
]
{:ok, %{bytecode: %NxIREE.Module{bytecode: bytecode}, output_container: output_container}} =
- NxIREE.Compiler.to_bytecode(fun, input_templates, iree_compiler_flags: compiler_flags)
+ NxIREE.Compiler.to_bytecode(fun, inputs, iree_compiler_flags: compiler_flags)
+
+ {_, num_outputs} =
+ Nx.Defn.Composite.traverse(output_container, 0, fn node, acc -> {node, acc + 1} end)
socket =
socket
|> assign(:bytecode, Base.encode64(bytecode))
|> assign(:output_container, output_container)
- |> assign(:function_signature, get_signature(function, input_templates, output_container))
+ |> assign(:function_signature, get_signature(function, inputs, output_container))
|> assign(:device, runtime_device)
|> assign(:reply_to_pid, reply_to_pid)
+ |> assign(:inputs, serialize_inputs(inputs))
+ |> assign(:num_outputs, num_outputs)
{:noreply, socket}
end
+ defp serialize_inputs(inputs) do
+ List.wrap(inputs)
+ |> Nx.Defn.Composite.flatten_list()
+ |> Enum.map(fn tensor ->
+ tensor = Nx.to_tensor(tensor)
+
+ {:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref)
+
+ serialized
+ end)
+ end
+
@impl true
def handle_event("nx-executed", params, socket) do
send(socket.assigns.reply_to_pid, {:nx, :executed, params})
diff --git a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex
index a60b504..7a4b363 100644
--- a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex
+++ b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate
index 61553c6..148045a 100644
Binary files a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate and b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate differ
diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift
index 608059d..3cfc91a 100644
--- a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift
+++ b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift
@@ -24,7 +24,7 @@ func nx_iree_initialize(
@_silgen_name("nx_iree_create_device")
func nx_iree_create_device(
_ driver_registry: UnsafeMutablePointer,
- _ name: UnsafePointer) -> UnsafeMutablePointer
+ _ name: UnsafePointer) -> UnsafeMutablePointer
@_silgen_name("nx_iree_call")
func nx_iree_call(
diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift
index 40ebeae..d838f74 100644
--- a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift
+++ b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift
@@ -14,7 +14,8 @@ struct NxFunctionView: View {
@LiveAttribute("bytecode") private var bytecode: String? = nil
@LiveAttribute("signature") private var signature: String? = nil
@LiveAttribute("device") private var deviceURI: String? = nil
- @LiveAttribute("trigger") private var trigger: Bool = false
+ @LiveAttribute("inputs") private var serializedInputs: [String]? = nil
+ @LiveAttribute("num-outputs") private var numOutputs: Int? = nil
@Event("on-execution", type: "change") private var change
var body: some View {
@@ -32,18 +33,125 @@ struct NxFunctionView: View {
}
}
- private func run() {
- if bytecode == nil {
- return
+ private func convertBase64StringToBytecode(_ base64String: String) -> (bytecodeSize: UInt64, bytecodePointer: UnsafePointer?)? {
+ // Step 1: Decode the Base64 string into Data
+ guard let decodedData = Data(base64Encoded: base64String) else {
+ print("Failed to decode base64 string.")
+ return nil
+ }
+
+ // Step 2: Get the size of the data
+ let bytecodeSize = UInt64(decodedData.count)
+
+ // Step 3: Convert Data to UnsafePointer
+ // We use `withUnsafeBytes` to get a pointer to the data
+ let bytecodePointer = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer? in
+ return pointer.bindMemory(to: CUnsignedChar.self).baseAddress
+ }
+
+ return (bytecodeSize, bytecodePointer)
+ }
+
+ private func convertToCStringArray(from strings: [String]) -> UnsafePointer>? {
+ // Array to hold the C strings (UnsafePointer)
+ var cStrings: [UnsafePointer] = []
+
+ for string in strings {
+ // Decode the base64 string to Data
+ guard let decodedData = Data(base64Encoded: string) else {
+ print("Failed to decode base64 string: \(string)")
+ return nil
+ }
+
+ // Convert Data to a C string (null-terminated UTF-8)
+ let cString = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer? in
+ guard let baseAddress = pointer.baseAddress else { return nil }
+ // Allocate memory for the C string and copy the data
+ let cStringPointer = UnsafeMutablePointer.allocate(capacity: decodedData.count + 1)
+ cStringPointer.initialize(from: baseAddress.assumingMemoryBound(to: CChar.self), count: decodedData.count)
+ cStringPointer[decodedData.count] = 0 // Null-terminate the string
+ return UnsafePointer(cStringPointer)
+ }
+
+ guard let cStr = cString else {
+ print("Failed to convert Data to C string.")
+ return nil
+ }
+
+ cStrings.append(cStr)
}
- if let vmInstance = globalVmInstance,
- let driverRegistry = globalDriverRegistry {
- print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")
- change(value: "Sending something back")
+ // Allocate memory for the array of C strings
+ let cStringsPointer = UnsafeMutablePointer>.allocate(capacity: cStrings.count)
+
+ // Copy the C strings to the allocated array
+ cStringsPointer.initialize(from: &cStrings, count: cStrings.count)
+
+ // Return the pointer to the array
+ return UnsafePointer(cStringsPointer)
+ }
+
+ private func base64EncodedStrings(from serializedOutputs: UnsafePointer>, count: Int) -> [String] {
+ // Convert UnsafePointer to a Swift array of UnsafePointer
+ let cStringPointers = Array(UnsafeBufferPointer(start: serializedOutputs, count: count))
+
+ var base64Strings: [String] = []
+
+ for cStringPointer in cStringPointers {
+ // Convert each C string to a Swift String
+ let string = String(cString: cStringPointer)
+
+ // Encode the string to Base64
+ if let data = string.data(using: .utf8) {
+ let base64String = data.base64EncodedString()
+ base64Strings.append(base64String)
+ }
+ }
+
+ return base64Strings
+ }
+
+
+ private func run() {
+ if bytecode != nil,
+ deviceURI != nil,
+ globalVmInstance != nil,
+ globalDriverRegistry != nil,
+ serializedInputs != nil,
+ let (bytecodeSize, bytecodePointer) = convertBase64StringToBytecode(bytecode!),
+ let inputs = convertToCStringArray(from: serializedInputs!) {
+ let deviceURIcstr = strdup(deviceURI!)
+ let device = nx_iree_create_device(globalDriverRegistry!, UnsafePointer(deviceURIcstr)!)
+ deviceURIcstr?.deallocate()
+
+ let serializedOutputs: UnsafePointer>? = nil
+ let errorMessage = UnsafeMutablePointer.allocate(capacity: 256)
+
+ print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")")
+
+ let result = nx_iree_call(
+ globalVmInstance!,
+ device,
+ bytecodeSize,
+ bytecodePointer!,
+ UInt64(serializedInputs!.count),
+ inputs,
+ UInt64(numOutputs!),
+ serializedOutputs!,
+ errorMessage)
+
+ if result != 0 {
+ print(errorMessage)
+ return
+ }
+
+ for i in 0..