diff --git a/liveview_native/live_nx_iree/config/runtime.exs b/liveview_native/live_nx_iree/config/runtime.exs index e6043a5..73a14f6 100644 --- a/liveview_native/live_nx_iree/config/runtime.exs +++ b/liveview_native/live_nx_iree/config/runtime.exs @@ -20,6 +20,8 @@ if System.get_env("PHX_SERVER") do config :live_nx_iree, LiveNxIREEWeb.Endpoint, server: true end +config :nx, :default_backend, NxIREE.Tensor + if config_env() == :prod do database_url = System.get_env("DATABASE_URL") || diff --git a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex index 018233f..670d244 100644 --- a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex +++ b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/home_live/home_live.ex @@ -15,7 +15,14 @@ defmodule LiveNxIREEWeb.HomeLive do dbg(self()) - socket = assign(socket, bytecode: nil, function_signature: nil, device: nil) + socket = + assign(socket, + bytecode: nil, + function_signature: nil, + device: nil, + inputs: nil, + num_outputs: nil + ) {:ok, socket} end @@ -43,7 +50,7 @@ defmodule LiveNxIREEWeb.HomeLive do # end @impl true - def handle_info({:nx, :execute, function, input_templates, target_device, reply_to_pid}, socket) do + def handle_info({:nx, :execute, function, inputs, target_device, reply_to_pid}, socket) do fun = case function do {m, f, a} -> @@ -71,19 +78,36 @@ defmodule LiveNxIREEWeb.HomeLive do ] {:ok, %{bytecode: %NxIREE.Module{bytecode: bytecode}, output_container: output_container}} = - NxIREE.Compiler.to_bytecode(fun, input_templates, iree_compiler_flags: compiler_flags) + NxIREE.Compiler.to_bytecode(fun, inputs, iree_compiler_flags: compiler_flags) + + {_, num_outputs} = + Nx.Defn.Composite.traverse(output_container, 0, fn node, acc -> {node, acc + 1} end) socket = socket |> assign(:bytecode, Base.encode64(bytecode)) |> assign(:output_container, output_container) - |> assign(:function_signature, get_signature(function, input_templates, output_container)) + |> assign(:function_signature, get_signature(function, inputs, output_container)) |> assign(:device, runtime_device) |> assign(:reply_to_pid, reply_to_pid) + |> assign(:inputs, serialize_inputs(inputs)) + |> assign(:num_outputs, num_outputs) {:noreply, socket} end + defp serialize_inputs(inputs) do + List.wrap(inputs) + |> Nx.Defn.Composite.flatten_list() + |> Enum.map(fn tensor -> + tensor = Nx.to_tensor(tensor) + + {:ok, serialized} = NxIREE.Native.serialize_tensor(tensor.data.ref) + + serialized + end) + end + @impl true def handle_event("nx-executed", params, socket) do send(socket.assigns.reply_to_pid, {:nx, :executed, params}) diff --git a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex index a60b504..7a4b363 100644 --- a/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex +++ b/liveview_native/live_nx_iree/lib/live_nx_iree_web/live/swiftui/home_live.swiftui.neex @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate index 61553c6..148045a 100644 Binary files a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate and b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE.xcodeproj/project.xcworkspace/xcuserdata/paulo.valente.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift index 608059d..3cfc91a 100644 --- a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift +++ b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon.swift @@ -24,7 +24,7 @@ func nx_iree_initialize( @_silgen_name("nx_iree_create_device") func nx_iree_create_device( _ driver_registry: UnsafeMutablePointer, - _ name: UnsafePointer) -> UnsafeMutablePointer + _ name: UnsafePointer) -> UnsafeMutablePointer @_silgen_name("nx_iree_call") func nx_iree_call( diff --git a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift index 40ebeae..d838f74 100644 --- a/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift +++ b/liveview_native/live_nx_iree/native/swiftui/LiveNxIREE/NxAddon/NxFunctionView.swift @@ -14,7 +14,8 @@ struct NxFunctionView: View { @LiveAttribute("bytecode") private var bytecode: String? = nil @LiveAttribute("signature") private var signature: String? = nil @LiveAttribute("device") private var deviceURI: String? = nil - @LiveAttribute("trigger") private var trigger: Bool = false + @LiveAttribute("inputs") private var serializedInputs: [String]? = nil + @LiveAttribute("num-outputs") private var numOutputs: Int? = nil @Event("on-execution", type: "change") private var change var body: some View { @@ -32,18 +33,125 @@ struct NxFunctionView: View { } } - private func run() { - if bytecode == nil { - return + private func convertBase64StringToBytecode(_ base64String: String) -> (bytecodeSize: UInt64, bytecodePointer: UnsafePointer?)? { + // Step 1: Decode the Base64 string into Data + guard let decodedData = Data(base64Encoded: base64String) else { + print("Failed to decode base64 string.") + return nil + } + + // Step 2: Get the size of the data + let bytecodeSize = UInt64(decodedData.count) + + // Step 3: Convert Data to UnsafePointer + // We use `withUnsafeBytes` to get a pointer to the data + let bytecodePointer = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer? in + return pointer.bindMemory(to: CUnsignedChar.self).baseAddress + } + + return (bytecodeSize, bytecodePointer) + } + + private func convertToCStringArray(from strings: [String]) -> UnsafePointer>? { + // Array to hold the C strings (UnsafePointer) + var cStrings: [UnsafePointer] = [] + + for string in strings { + // Decode the base64 string to Data + guard let decodedData = Data(base64Encoded: string) else { + print("Failed to decode base64 string: \(string)") + return nil + } + + // Convert Data to a C string (null-terminated UTF-8) + let cString = decodedData.withUnsafeBytes { (pointer: UnsafeRawBufferPointer) -> UnsafePointer? in + guard let baseAddress = pointer.baseAddress else { return nil } + // Allocate memory for the C string and copy the data + let cStringPointer = UnsafeMutablePointer.allocate(capacity: decodedData.count + 1) + cStringPointer.initialize(from: baseAddress.assumingMemoryBound(to: CChar.self), count: decodedData.count) + cStringPointer[decodedData.count] = 0 // Null-terminate the string + return UnsafePointer(cStringPointer) + } + + guard let cStr = cString else { + print("Failed to convert Data to C string.") + return nil + } + + cStrings.append(cStr) } - if let vmInstance = globalVmInstance, - let driverRegistry = globalDriverRegistry { - print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")") - change(value: "Sending something back") + // Allocate memory for the array of C strings + let cStringsPointer = UnsafeMutablePointer>.allocate(capacity: cStrings.count) + + // Copy the C strings to the allocated array + cStringsPointer.initialize(from: &cStrings, count: cStrings.count) + + // Return the pointer to the array + return UnsafePointer(cStringsPointer) + } + + private func base64EncodedStrings(from serializedOutputs: UnsafePointer>, count: Int) -> [String] { + // Convert UnsafePointer to a Swift array of UnsafePointer + let cStringPointers = Array(UnsafeBufferPointer(start: serializedOutputs, count: count)) + + var base64Strings: [String] = [] + + for cStringPointer in cStringPointers { + // Convert each C string to a Swift String + let string = String(cString: cStringPointer) + + // Encode the string to Base64 + if let data = string.data(using: .utf8) { + let base64String = data.base64EncodedString() + base64Strings.append(base64String) + } + } + + return base64Strings + } + + + private func run() { + if bytecode != nil, + deviceURI != nil, + globalVmInstance != nil, + globalDriverRegistry != nil, + serializedInputs != nil, + let (bytecodeSize, bytecodePointer) = convertBase64StringToBytecode(bytecode!), + let inputs = convertToCStringArray(from: serializedInputs!) { + let deviceURIcstr = strdup(deviceURI!) + let device = nx_iree_create_device(globalDriverRegistry!, UnsafePointer(deviceURIcstr)!) + deviceURIcstr?.deallocate() + + let serializedOutputs: UnsafePointer>? = nil + let errorMessage = UnsafeMutablePointer.allocate(capacity: 256) + + print("Executing function \(signature ?? "None") on device: \(deviceURI ?? "None")") + + let result = nx_iree_call( + globalVmInstance!, + device, + bytecodeSize, + bytecodePointer!, + UInt64(serializedInputs!.count), + inputs, + UInt64(numOutputs!), + serializedOutputs!, + errorMessage) + + if result != 0 { + print(errorMessage) + return + } + + for i in 0..