Skip to content

Commit

Permalink
Allow multiple instances of WasmInterpreter to have imported functions (
Browse files Browse the repository at this point in the history
#14)

This commit fixes a bug in which deinitializing an instance of `WasmInterpreter` would remove all cached implementations for functions imported by any `WasmInterpreter` instance, not just the one being deinitialized.

Now, we generate a unique identifier for each instance of `WasmInterpreter`, and we use that identifier as the key to its cached imported functions.

This pull request also increases to 10 the maximum number of Wasm functions that can share the same implementation in the same module. Previously, the limit was 3.
  • Loading branch information
atdrendel authored Aug 3, 2021
1 parent 3611d65 commit 3dab007
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 83 deletions.
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"repositoryURL": "https://github.com/shareup/cwasm3.git",
"state": {
"branch": null,
"revision": "1e7c4db769e0638d3c0833a941ffb475100a5c0d",
"version": "0.5.0"
"revision": "38335527c3ae87017d2c8993816d0e9612d18fc7",
"version": "0.5.1"
}
},
{
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ let package = Package(
.package(
name: "CWasm3",
url: "https://github.com/shareup/cwasm3.git",
from: "0.5.0"),
from: "0.5.1"),
.package(
name: "Synchronized",
url: "https://github.com/shareup/synchronized.git",
Expand Down
88 changes: 88 additions & 0 deletions Sources/WasmInterpreter/ImportedFunctionCache.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import Foundation
import Synchronized
import CWasm3

// MARK: - Managing imported functions

func setImportedFunction(
_ function: @escaping ImportedFunctionSignature,
for context: UnsafeMutableRawPointer,
instanceIdentifier id: UInt64
) {
lock.locked {
if var functionsForID = importedFunctionCache[id] {
functionsForID[context] = function
importedFunctionCache[id] = functionsForID
} else {
let functionsForID = [context: function]
importedFunctionCache[id] = functionsForID
}
}
}

func removeImportedFunction(
for context: UnsafeMutableRawPointer,
instanceIdentifier id: UInt64
) {
lock.locked {
guard var functionsForID = importedFunctionCache[id] else { return }
functionsForID.removeValue(forKey: context)
importedFunctionCache[id] = functionsForID
}
}

func removeImportedFunctions(forInstanceIdentifier id: UInt64) {
lock.locked {
_ = importedFunctionCache.removeValue(forKey: id)
}
}

func importedFunction(
for userData: UnsafeMutableRawPointer?,
instanceIdentifier id: UInt64
) -> ImportedFunctionSignature? {
guard let context = userData else { return nil }
return lock.locked { importedFunctionCache[id]?[context] }
}

func handleImportedFunction(
_ runtime: UnsafeMutablePointer<M3Runtime>?,
_ context: UnsafeMutablePointer<M3ImportContext>?,
_ stackPointer: UnsafeMutablePointer<UInt64>?,
_ heap: UnsafeMutableRawPointer?
) -> UnsafeRawPointer? {
guard let id = m3_GetUserData(runtime)?.load(as: UInt64.self)
else { return UnsafeRawPointer(m3Err_trapUnreachable) }

guard let userData = context?.pointee.userdata
else { return UnsafeRawPointer(m3Err_trapUnreachable) }

guard let function = importedFunction(for: userData, instanceIdentifier: id)
else { return UnsafeRawPointer(m3Err_trapUnreachable) }

return function(stackPointer, heap)
}

// MARK: - Generating instance identifiers

var nextInstanceIdentifier: UInt64 {
lock.locked {
lastInstanceIdentifier += 1
return lastInstanceIdentifier
}
}

func makeRawPointer(for id: UInt64) -> UnsafeMutableRawPointer {
let ptr = UnsafeMutableRawPointer.allocate(
byteCount: MemoryLayout<UInt64>.size,
alignment: MemoryLayout<UInt64>.alignment
)
ptr.storeBytes(of: id, as: UInt64.self)
return ptr
}

// MARK: - Private

private let lock = Lock()
private var lastInstanceIdentifier: UInt64 = 0
private var importedFunctionCache = [UInt64: [UnsafeMutableRawPointer: ImportedFunctionSignature]]()
90 changes: 30 additions & 60 deletions Sources/WasmInterpreter/WasmInterpreter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ import CWasm3
import Synchronized

public final class WasmInterpreter {
private var _environment: IM3Environment
private var _runtime: IM3Runtime
private var _moduleAndBytes: (IM3Module, [UInt8])
private var _module: IM3Module { _moduleAndBytes.0 }
private var id: UInt64
private var idPointer: UnsafeMutableRawPointer

private var _functionCache = [String: IM3Function]()
private var _importedFunctionContexts = [UnsafeMutableRawPointer]()
private var environment: IM3Environment
private var runtime: IM3Runtime
private var moduleAndBytes: (IM3Module, [UInt8])
private var module: IM3Module { moduleAndBytes.0 }

private let _lock = Lock()
private var functionCache = [String: IM3Function]()
private var importedFunctionContexts = [UnsafeMutableRawPointer]()

private let lock = Lock()

public convenience init(module: URL) throws {
try self.init(stackSize: 512 * 1024, module: module)
Expand All @@ -26,11 +29,14 @@ public final class WasmInterpreter {
}

public init(stackSize: UInt32, module bytes: [UInt8]) throws {
id = nextInstanceIdentifier
idPointer = makeRawPointer(for: id)

guard let environment = m3_NewEnvironment() else {
throw WasmInterpreterError.couldNotLoadEnvironment
}

guard let runtime = m3_NewRuntime(environment, stackSize, nil) else {
guard let runtime = m3_NewRuntime(environment, stackSize, idPointer) else {
throw WasmInterpreterError.couldNotLoadRuntime
}

Expand All @@ -39,15 +45,16 @@ public final class WasmInterpreter {
guard let module = mod else { throw WasmInterpreterError.couldNotParseModule }
try WasmInterpreter.check(m3_LoadModule(runtime, module))

_environment = environment
_runtime = runtime
_moduleAndBytes = (module, bytes)
self.environment = environment
self.runtime = runtime
moduleAndBytes = (module, bytes)
}

deinit {
m3_FreeRuntime(_runtime)
m3_FreeEnvironment(_environment)
removeImportedFunctions(for: _importedFunctionContexts)
m3_FreeRuntime(runtime)
m3_FreeEnvironment(environment)
removeImportedFunctions(forInstanceIdentifier: id)
idPointer.deallocate()
}
}

Expand Down Expand Up @@ -150,7 +157,7 @@ extension WasmInterpreter {
let totalBytes = UnsafeMutablePointer<UInt32>.allocate(capacity: 1)
defer { totalBytes.deallocate() }

guard let bytesPointer = m3_GetMemory(_runtime, totalBytes, 0)
guard let bytesPointer = m3_GetMemory(runtime, totalBytes, 0)
else { throw WasmInterpreterError.invalidMemoryAccess }

return Heap(pointer: bytesPointer, size: Int(totalBytes.pointee))
Expand Down Expand Up @@ -205,36 +212,36 @@ extension WasmInterpreter {
else { throw WasmInterpreterError.couldNotGenerateFunctionContext }

do {
setImportedFunction(handler, for: context)
setImportedFunction(handler, for: context, instanceIdentifier: id)
try WasmInterpreter.check(
m3_LinkRawFunctionEx(
_module,
module,
namespace,
name,
signature,
handleImportedFunction,
context
)
)
_lock.locked { _importedFunctionContexts.append(context) }
lock.locked { importedFunctionContexts.append(context) }
} catch {
removeImportedFunction(for: context)
removeImportedFunction(for: context, instanceIdentifier: id)
throw error
}
}
}

extension WasmInterpreter {
func function(named name: String) throws -> IM3Function {
return try _lock.locked { () throws -> IM3Function in
if let compiledFunction = _functionCache[name] {
return try lock.locked { () throws -> IM3Function in
if let compiledFunction = functionCache[name] {
return compiledFunction
} else {
var f: IM3Function?
try WasmInterpreter.check(m3_FindFunction(&f, _runtime, name))
try WasmInterpreter.check(m3_FindFunction(&f, runtime, name))
guard let function = f
else { throw WasmInterpreterError.couldNotFindFunction(name) }
_functionCache[name] = function
functionCache[name] = function
return function
}
}
Expand Down Expand Up @@ -279,40 +286,3 @@ extension WasmInterpreter {
}
}
}

private let importedFunctionLock = Lock()
private var contextToImportedFunction = Dictionary<UnsafeMutableRawPointer, ImportedFunctionSignature>()

private func setImportedFunction(_ function: @escaping ImportedFunctionSignature, for context: UnsafeMutableRawPointer) {
importedFunctionLock.locked { contextToImportedFunction[context] = function }
}

private func removeImportedFunction(for context: UnsafeMutableRawPointer) {
importedFunctionLock.locked { _ = contextToImportedFunction.removeValue(forKey: context) }
}

private func removeImportedFunctions(for contexts: [UnsafeMutableRawPointer]) {
importedFunctionLock.locked { contexts.forEach { contextToImportedFunction.removeValue(forKey: $0) } }
}

private func importedFunction(
for userData: UnsafeMutableRawPointer?
) -> ImportedFunctionSignature? {
guard let context = userData else { return nil }
return importedFunctionLock.locked { contextToImportedFunction[context] }
}

private func handleImportedFunction(
_ runtime: UnsafeMutablePointer<M3Runtime>?,
_ context: UnsafeMutablePointer<M3ImportContext>?,
_ stackPointer: UnsafeMutablePointer<UInt64>?,
_ heap: UnsafeMutableRawPointer?
) -> UnsafeRawPointer? {
guard let userData = context?.pointee.userdata
else { return UnsafeRawPointer(m3Err_trapUnreachable) }

guard let function = importedFunction(for: userData)
else { return UnsafeRawPointer(m3Err_trapUnreachable) }

return function(stackPointer, heap)
}
9 changes: 8 additions & 1 deletion Tests/WasmInterpreterTests/Resources/constant.wat
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,11 @@
(export "constant_1" (func 0))
(export "constant_2" (func 0))
(export "constant_3" (func 0))
(export "constant_4" (func 0)))
(export "constant_4" (func 0))
(export "constant_5" (func 0))
(export "constant_6" (func 0))
(export "constant_7" (func 0))
(export "constant_8" (func 0))
(export "constant_9" (func 0))
(export "constant_10" (func 0))
(export "constant_11" (func 0)))
18 changes: 3 additions & 15 deletions Tests/WasmInterpreterTests/Wasm Modules/ConstantModule.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,13 @@ public struct ConstantModule {
_vm = try WasmInterpreter(module: ConstantModule.wasm)
}

func constant1() throws -> Int {
return Int(try _vm.call("constant_1") as Int32)
}

func constant2() throws -> Int {
return Int(try _vm.call("constant_2") as Int32)
}

func constant3() throws -> Int {
return Int(try _vm.call("constant_3") as Int32)
}

func constant4() throws -> Int {
return Int(try _vm.call("constant_4") as Int32)
func constant(version: Int) throws -> Int {
return Int(try _vm.call("constant_\(version)") as Int32)
}

// `wat2wasm -o >(base64) Tests/WasmInterpreterTests/Resources/constant.wat | pbcopy`
private static var wasm: [UInt8] {
let base64 = "AGFzbQEAAAABBQFgAAF/AwIBAAc1BApjb25zdGFudF8xAAAKY29uc3RhbnRfMgAACmNvbnN0YW50XzMAAApjb25zdGFudF80AAAKCAEGAEGAgAQL"
let base64 = "AGFzbQEAAAABBQFgAAF/AwIBAAeSAQsKY29uc3RhbnRfMQAACmNvbnN0YW50XzIAAApjb25zdGFudF8zAAAKY29uc3RhbnRfNAAACmNvbnN0YW50XzUAAApjb25zdGFudF82AAAKY29uc3RhbnRfNwAACmNvbnN0YW50XzgAAApjb25zdGFudF85AAALY29uc3RhbnRfMTAAAAtjb25zdGFudF8xMQAACggBBgBBgIAECw=="
return Array<UInt8>(Data(base64Encoded: base64)!)
}
}
19 changes: 15 additions & 4 deletions Tests/WasmInterpreterTests/WasmInterpreterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import XCTest
final class WasmInterpreterTests: XCTestCase {
func testCallingTwoFunctionsWithSameImplementation() throws {
let mod = try ConstantModule()
XCTAssertEqual(65536, try mod.constant1())
XCTAssertEqual(65536, try mod.constant2())
XCTAssertEqual(65536, try mod.constant3())
XCTAssertThrowsError(try mod.constant4()) { (error) in

try (1...10).forEach { XCTAssertEqual(65536, try mod.constant(version: $0)) }

XCTAssertThrowsError(try mod.constant(version: 11)) { (error) in
guard case let .wasm3Error(msg) = error as? WasmInterpreterError
else { XCTFail(); return }
XCTAssertEqual("function lookup failed", msg)
Expand Down Expand Up @@ -36,6 +36,16 @@ final class WasmInterpreterTests: XCTestCase {
XCTAssertEqual(-3291, try mod.askModuleToCallImportedFunction())
}

func testConcurrentModulesWithImportedFunctions() throws {
var mod1: ImportedAddModule? = try ImportedAddModule()
let mod2 = try ImportedAddModule()

XCTAssertEqual(-3291, try mod1?.askModuleToCallImportedFunction())
mod1 = nil

XCTAssertEqual(-3291, try mod2.askModuleToCallImportedFunction())
}

func testAccessingAndModifyingHeapMemory() throws {
let mod = try MemoryModule()

Expand Down Expand Up @@ -96,6 +106,7 @@ final class WasmInterpreterTests: XCTestCase {
("testPassingAndReturning32BitValues", testPassingAndReturning32BitValues),
("testPassingAndReturning64BitValues", testPassingAndReturning64BitValues),
("testUsingImportedFunction", testUsingImportedFunction),
("testConcurrentModulesWithImportedFunctions", testConcurrentModulesWithImportedFunctions),
("testAccessingAndModifyingHeapMemory", testAccessingAndModifyingHeapMemory),
("testAccessingInvalidMemoryAddresses", testAccessingInvalidMemoryAddresses),
]
Expand Down

0 comments on commit 3dab007

Please sign in to comment.