Skip to content

Commit

Permalink
v1: WATM v1 driver (#65)
Browse files Browse the repository at this point in the history
* deprecated: use OverrideLogger not SetDefaultXxx

- SetDefaultLogger is now deprecated
- SetDefaultLogHandler is now deprecated

Signed-off-by: Gaukas Wang <[email protected]>

* docs: comments update

Signed-off-by: Gaukas Wang <[email protected]>

* new: DialerManaged interface

- Add DialManaged as a new interface which works like a dialer without remote address as input.
- Naming is tentative.

Signed-off-by: Gaukas Wang <[email protected]>

* update: transport module config as bytes

Starting v1 we seek to use memory-based file system instead of tmpfs for watm config.

Signed-off-by: Gaukas Wang <[email protected]>

* new: Config.DialedAddressValidator

A new util for validating WATM-provided dialing destination address.

Signed-off-by: Gaukas Wang <[email protected]>

* update: rename DialerManaged to Connector

Signed-off-by: Gaukas Wang <[email protected]>

* update: allow cancel context per core

Add additional cancellable wrapper around the input context to allow core-level cancellation.

Signed-off-by: Gaukas Wang <[email protected]>

* new: WATM Driver v1

v1: rename, reformat, and regulate/standardize the function import/exports.

Signed-off-by: Gaukas Wang <[email protected]>

---------

Signed-off-by: Gaukas Wang <[email protected]>
  • Loading branch information
gaukas authored Apr 8, 2024
1 parent c2975bb commit f420809
Show file tree
Hide file tree
Showing 32 changed files with 4,035 additions and 29 deletions.
38 changes: 29 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type Config struct {
// net.Dial(network, address)
NetworkDialerFunc func(network, address string) (net.Conn, error)

// DialedAddressValidator is an optional field that can be set to validate
// the dialed address. It is only used when WATM specifies the remote
// address to dial.
//
// If not set, all addresses are considered invalid. To allow all addresses,
// simply set this field to a function that always returns nil.
DialedAddressValidator func(network, address string) error

// NetworkListener specifies a net.listener implementation that listens
// on the specified address on the named network. This optional field
// will be used to provide (incoming) network connections from a
Expand All @@ -44,13 +52,24 @@ type Config struct {
// and/or debugging purposes only.
//
// Caller is supposed to call c.ModuleConfig() to get the pointer to the
// ModuleConfigFactory. If the pointer is nil, a new ModuleConfigFactory will
// ModuleConfigFactory. If this field is unset, a new ModuleConfigFactory will
// be created and returned.
ModuleConfigFactory *WazeroModuleConfigFactory

// RuntimeConfigFactory is used to configure the runtime behavior of
// each WASM instance created. This field is for advanced use cases
// and/or debugging purposes only.
//
// Caller is supposed to call c.RuntimeConfig() to get the pointer to the
// RuntimeConfigFactory. If this field is unset, a new RuntimeConfigFactory will
// be created and returned.
RuntimeConfigFactory *WazeroRuntimeConfigFactory

OverrideLogger *log.Logger // essentially a *slog.Logger, currently using an alias to flatten the version discrepancy
// OverrideLogger is a slog.Logger, used by WATER to log messages including
// debugging information, warnings, errors that cannot be returned to the caller
// of the WATER API. If this field is unset, the default logger from the slog
// package will be used.
OverrideLogger *log.Logger
}

// Clone creates a deep copy of the Config.
Expand All @@ -63,13 +82,14 @@ func (c *Config) Clone() *Config {
copy(wasmClone, c.TransportModuleBin)

return &Config{
TransportModuleBin: wasmClone,
TransportModuleConfig: c.TransportModuleConfig,
NetworkDialerFunc: c.NetworkDialerFunc,
NetworkListener: c.NetworkListener,
ModuleConfigFactory: c.ModuleConfigFactory.Clone(),
RuntimeConfigFactory: c.RuntimeConfigFactory.Clone(),
OverrideLogger: c.OverrideLogger,
TransportModuleBin: wasmClone,
TransportModuleConfig: c.TransportModuleConfig,
NetworkDialerFunc: c.NetworkDialerFunc,
DialedAddressValidator: c.DialedAddressValidator,
NetworkListener: c.NetworkListener,
ModuleConfigFactory: c.ModuleConfigFactory.Clone(),
RuntimeConfigFactory: c.RuntimeConfigFactory.Clone(),
OverrideLogger: c.OverrideLogger,
}
}

Expand Down
2 changes: 1 addition & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func testConfigCloneValid(t *testing.T) {
f.Set(reflect.ValueOf(make([]byte, 256)))
case "TransportModuleConfig":
f.Set(reflect.ValueOf(water.TransportModuleConfigFromBytes([]byte("foo"))))
case "NetworkDialerFunc": // functions aren't deeply equal unless nil
case "NetworkDialerFunc", "DialedAddressValidator": // functions aren't deeply equal unless nil
continue
case "NetworkListener":
f.Set(reflect.ValueOf(&net.TCPListener{}))
Expand Down
100 changes: 92 additions & 8 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package water

import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
Expand All @@ -11,7 +14,11 @@ import (
"github.com/refraction-networking/water/internal/log"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental/sys"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"

"github.com/karelbilek/wazero-fs-tools/memfs"
expsysfs "github.com/tetratelabs/wazero/experimental/sysfs"
)

var (
Expand All @@ -33,6 +40,9 @@ type Core interface {
// Context returns the base context used by the Core.
Context() context.Context

// ContextCancel cancels the base context used by the Core.
ContextCancel()

// Close closes the Core and releases all the resources
// associated with it.
Close() error
Expand Down Expand Up @@ -107,6 +117,9 @@ type Core interface {
// If the target function is not exported, this function returns an error.
Invoke(funcName string, params ...uint64) (results []uint64, err error)

// ReadIovs reads data from the memory pointed by iovs and writes it to buf.
ReadIovs(iovs, iovsLen int32, buf []byte) (int, error)

// WASIPreview1 enables the WASI preview1 API.
//
// It is recommended that this function only to be invoked if
Expand Down Expand Up @@ -139,10 +152,11 @@ type core struct {
// config
config *Config

ctx context.Context
runtime wazero.Runtime
module wazero.CompiledModule
instance api.Module
ctx context.Context
ctxCancel context.CancelFunc
runtime wazero.Runtime
module wazero.CompiledModule
instance api.Module

// saved after Exports() is called
exportsLoadOnce sync.Once
Expand Down Expand Up @@ -186,7 +200,7 @@ func NewCoreWithContext(ctx context.Context, config *Config) (Core, error) {
importModules: make(map[string]wazero.HostModuleBuilder),
}

c.ctx = ctx
c.ctx, c.ctxCancel = context.WithCancel(ctx)
c.runtime = wazero.NewRuntimeWithConfig(ctx, config.RuntimeConfig().GetConfig())

if c.module, err = c.runtime.CompileModule(ctx, c.config.WATMBinOrPanic()); err != nil {
Expand All @@ -210,6 +224,11 @@ func (c *core) Context() context.Context {
return c.ctx
}

// ContextCancel implements Core.
func (c *core) ContextCancel() {
c.ctxCancel()
}

func (c *core) cleanup() {
for i := range c.importModules {
delete(c.importModules, i)
Expand Down Expand Up @@ -256,6 +275,17 @@ func (c *core) Close() error {
log.LDebugf(c.config.Logger(), "MODULE DROPPED")
}

if c.ctxCancel != nil {
c.ctxCancel()
c.ctxCancel = nil
log.LDebugf(c.config.Logger(), "CONTEXT CANCELED")
}

if c.ctx != nil {
c.ctx = nil // TODO: force dropped
log.LDebugf(c.config.Logger(), "CONTEXT DROPPED")
}

c.cleanup()
})

Expand Down Expand Up @@ -311,10 +341,10 @@ func (c *core) ImportFunction(module, name string, f any) error {
// Unsafe: check if the WebAssembly module really imports this function under
// the given module and name. If not, we warn and skip the import.
if mod, ok := c.ImportedFunctions()[module]; !ok {
log.LDebugf(c.config.Logger(), "water: module %s is not imported.", module)
log.LDebugf(c.config.Logger(), "water: module %s is not imported by the WebAssembly module.", module)
return ErrModuleNotImported
} else if _, ok := mod[name]; !ok {
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported.", module, name)
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported by the WebAssembly module.", module, name)
return ErrFuncNotImported
}

Expand Down Expand Up @@ -350,6 +380,28 @@ func (c *core) Instantiate() (err error) {
}
}

// If TransportModuleConfig is set, we pass the config to the runtime.
if c.config.TransportModuleConfig != nil {
mc := c.config.ModuleConfig()
fsCfg := mc.GetFSConfig()
if fsCfg == nil {
fsCfg = wazero.NewFSConfig()

}

memFS := memfs.New()

err := memFS.WriteFile("watm.cfg", c.config.TransportModuleConfig.AsBytes())
if errors.Is(err, nil) || errors.Is(err, sys.Errno(0)) {
return fmt.Errorf("water: memFS.WriteFile returned error: %w", err)
}

if expFsCfg, ok := fsCfg.(expsysfs.FSConfig); ok {
fsCfg = expFsCfg.WithSysFSMount(memFS, "/conf/")
mc.SetFSConfig(fsCfg)
}
}

if c.instance, err = c.runtime.InstantiateModule(
c.ctx,
c.module,
Expand All @@ -373,12 +425,44 @@ func (c *core) Invoke(funcName string, params ...uint64) (results []uint64, err

results, err = expFunc.Call(c.ctx, params...)
if err != nil {
return nil, fmt.Errorf("water: (*wazero.ExportedFunction).Call returned error: %w", err)
return nil, fmt.Errorf("water: (*wazero.ExportedFunction)%q.Call returned error: %w", funcName, err)
}

return
}

var le = binary.LittleEndian

// adapted from fd_write implementation in wazero
func (c *core) ReadIovs(iovs, iovsLen int32, buf []byte) (n int, err error) {
mem := c.instance.Memory()

iovsStop := uint32(iovsLen) << 3 // iovsCount * 8
iovsBuf, ok := mem.Read(uint32(iovs), iovsStop)
if !ok {
return 0, errors.New("ReadIovs: failed to read iovs from memory")
}

for iovsPos := uint32(0); iovsPos < iovsStop; iovsPos += 8 {
offset := le.Uint32(iovsBuf[iovsPos:])
l := le.Uint32(iovsBuf[iovsPos+4:])

b, ok := mem.Read(offset, l)
if !ok {
return 0, errors.New("ReadIovs: failed to read iov from memory")
}

// Write to buf
nCopied := copy(buf[n:], b)
n += nCopied

if nCopied != len(b) {
return n, io.ErrShortBuffer
}
}
return
}

// WASIPreview1 implements Core.
func (c *core) WASIPreview1() error {
if _, err := wasi_snapshot_preview1.Instantiate(c.ctx, c.runtime); err != nil {
Expand Down
81 changes: 78 additions & 3 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
// +----------------+
// Dialer
type Dialer interface {
// Dial dials the remote network address and returns a net.Conn.
// Dial dials the remote network address and returns a
// superset of net.Conn.
//
// It is recommended to use DialContext instead of Dial.
// It is recommended to use DialContext instead of Dial. This
// method may be removed in the future.
Dial(network, address string) (Conn, error)

// DialContext dials the remote network address with the given context
// and returns a net.Conn.
// and returns a superset of net.Conn.
DialContext(ctx context.Context, network, address string) (Conn, error)

mustEmbedUnimplementedDialer()
Expand Down Expand Up @@ -121,3 +123,76 @@ func NewDialerWithContext(ctx context.Context, c *Config) (Dialer, error) {

return nil, ErrDialerVersionNotFound
}

// FixedDialer acts like a dialer, despite the fact that the destination is managed by
// the WebAssembly Transport Module (WATM) instead of specified by the caller.
//
// In other words, FixedDialer is a dialer that does not take network or address as input
// but returns a connection to a remote network address specified by the WATM.
type FixedDialer interface {
// DialFixed dials a remote network address provided by the WATM
// and returns a superset of net.Conn.
//
// It is recommended to use DialFixedContext instead of Connect. This
// method may be removed in the future.
DialFixed() (Conn, error)

// DialFixedContext dials a remote network address provided by the WATM
// with the given context and returns a superset of net.Conn.
DialFixedContext(ctx context.Context) (Conn, error)

mustEmbedUnimplementedFixedDialer()
}

type newFixedDialerFunc func(context.Context, *Config) (FixedDialer, error)

var (
knownFixedDialerVersions = make(map[string]newFixedDialerFunc)

ErrFixedDialerAlreadyRegistered = errors.New("water: free dialer already registered")
ErrFixedDialerVersionNotFound = errors.New("water: free dialer version not found")
ErrUnimplementedFixedDialer = errors.New("water: unimplemented free dialer")

_ FixedDialer = (*UnimplementedFixedDialer)(nil) // type guard
)

// UnimplementedFixedDialer is a FixedDialer that always returns errors.
//
// It is used to ensure forward compatibility of the FixedDialer interface.
type UnimplementedFixedDialer struct{}

// Connect implements FixedDialer.DialFixed().
func (*UnimplementedFixedDialer) DialFixed() (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

// DialFixedContext implements FixedDialer.DialFixedContext().
func (*UnimplementedFixedDialer) DialFixedContext(_ context.Context) (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

func (*UnimplementedFixedDialer) mustEmbedUnimplementedFixedDialer() {} //nolint:unused

func RegisterWATMFixedDialer(name string, dialer newFixedDialerFunc) error {
if _, ok := knownFixedDialerVersions[name]; ok {
return ErrFixedDialerAlreadyRegistered
}
knownFixedDialerVersions[name] = dialer
return nil
}

func NewFixedDialerWithContext(ctx context.Context, cfg *Config) (FixedDialer, error) {
core, err := NewCoreWithContext(ctx, cfg)
if err != nil {
return nil, err
}

// Sniff the version of the dialer
for exportName := range core.Exports() {
if f, ok := knownFixedDialerVersions[exportName]; ok {
return f(ctx, cfg)
}
}

return nil, ErrFixedDialerVersionNotFound
}
6 changes: 5 additions & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net"

"github.com/refraction-networking/water"
_ "github.com/refraction-networking/water/transport/v0"
_ "github.com/refraction-networking/water/transport/v1"
)

// ExampleDialer demonstrates how to use water.Dialer.
Expand Down Expand Up @@ -66,6 +66,10 @@ func ExampleDialer() {
panic("short read")
}

if err := waterConn.Close(); err != nil {
panic(err)
}

fmt.Println(string(buf[:n]))
// Output: olleh
}
Expand Down
Loading

0 comments on commit f420809

Please sign in to comment.