diff --git a/wampv2/dealer.go b/wampv2/dealer.go new file mode 100644 index 0000000..528ce80 --- /dev/null +++ b/wampv2/dealer.go @@ -0,0 +1,149 @@ +package wampv2 + +type Callee interface { + ErrorHandler + // Acknowledge that the endpoint was succesfully registered + SendRegistered(*Registered) + // Acknowledge that the endpoint was succesfully unregistered + SendUnregistered(*Unregistered) + // Dealer requests fulfillment of a procedure call + SendInvocation(*Invocation) +} + +type Caller interface { + ErrorHandler + // Dealer sends the returned result from the procedure call + SendResult(*Result) +} + +type Dealer interface { + // Register a procedure on an endpoint + Register(Callee, *Register) + // Unregister a procedure on an endpoint + Unregister(Callee, *Unregister) + // Call a procedure on an endpoint + Call(Caller, *Call) + // Return the result of a procedure call + Yield(Callee, *Yield) +} + +type RemoteProcedure struct { + Endpoint Callee + Procedure URI +} + +type DefaultDealer struct { + // map registration IDs to procedures + procedures map[ID]RemoteProcedure + // map procedure URIs to registration IDs + // TODO: this will eventually need to be `map[URI][]ID` to support + // multiple callees for the same procedure + registrations map[URI]ID + // keep track of call IDs so we can send the response to the caller + calls map[ID]Caller + // link the invocation ID to the call ID + invocations map[ID]ID +} + +func NewDefaultDealer() *DefaultDealer { + return &DefaultDealer{ + procedures: make(map[ID]RemoteProcedure), + registrations: make(map[URI]ID), + calls: make(map[ID]Caller), + invocations: make(map[ID]ID), + } +} + +func (d *DefaultDealer) Register(callee Callee, msg *Register) { + if _, ok := d.registrations[msg.Procedure]; ok { + callee.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + Error: WAMP_ERROR_PROCEDURE_ALREADY_EXISTS, + }) + return + } + reg := NewID() + d.procedures[reg] = RemoteProcedure{callee, msg.Procedure} + d.registrations[msg.Procedure] = reg + callee.SendRegistered(&Registered{ + Request: msg.Request, + Registration: reg, + }) +} + +func (d *DefaultDealer) Unregister(callee Callee, msg *Unregister) { + if procedure, ok := d.procedures[msg.Registration]; !ok { + // the registration doesn't exist + callee.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + Error: WAMP_ERROR_NO_SUCH_REGISTRATION, + }) + } else { + delete(d.registrations, procedure.Procedure) + delete(d.procedures, msg.Registration) + callee.SendUnregistered(&Unregistered{ + Request: msg.Request, + }) + } +} + +func (d *DefaultDealer) Call(caller Caller, msg *Call) { + if reg, ok := d.registrations[msg.Procedure]; !ok { + caller.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + Error: WAMP_ERROR_NO_SUCH_PROCEDURE, + }) + } else { + if rproc, ok := d.procedures[reg]; !ok { + // found a registration id, but doesn't match any remote procedure + caller.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + // TODO: what should this error be? + Error: URI("wamp.error.internal_error"), + }) + } else { + // everything checks out, make the invocation request + d.calls[msg.Request] = caller + invocationID := NewID() + d.invocations[invocationID] = msg.Request + rproc.Endpoint.SendInvocation(&Invocation{ + Request: invocationID, + Registration: reg, + Arguments: msg.Arguments, + ArgumentsKw: msg.ArgumentsKw, + }) + } + } +} + +func (d *DefaultDealer) Yield(callee Callee, msg *Yield) { + if callID, ok := d.invocations[msg.Request]; !ok { + callee.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + // TODO: what should this error be? + Error: URI("wamp.error.no_such_invocation"), + }) + } else { + if caller, ok := d.calls[callID]; !ok { + // found the invocation id, but doesn't match any call id + callee.SendError(&Error{ + Type: msg.MessageType(), + Request: msg.Request, + // TODO: what should this error be? + Error: URI("wamp.error.no_such_call"), + }) + } else { + // return the result to the caller + caller.SendResult(&Result{ + Request: callID, + Arguments: msg.Arguments, + ArgumentsKw: msg.ArgumentsKw, + }) + } + } +} diff --git a/wampv2/dealer_test.go b/wampv2/dealer_test.go new file mode 100644 index 0000000..3070d80 --- /dev/null +++ b/wampv2/dealer_test.go @@ -0,0 +1,46 @@ +package wampv2 + +import ( + . "github.com/smartystreets/goconvey/convey" + "testing" +) + +type TestCallee struct { + received Message +} + +func (c *TestCallee) SendError(msg *Error) { c.received = msg } +func (c *TestCallee) SendRegistered(msg *Registered) { c.received = msg } +func (c *TestCallee) SendUnregistered(msg *Unregistered) { c.received = msg } +func (c *TestCallee) SendInvocation(msg *Invocation) { c.received = msg } + +func TestRegister(t *testing.T) { + Convey("Registering a procedure", t, func() { + dealer := NewDefaultDealer() + callee := &TestCallee{} + testProcedure := URI("turnpike.test.endpoint") + msg := &Register{Request: 123, Procedure: testProcedure} + dealer.Register(callee, msg) + + Convey("The callee should have received a REGISTERED message", func() { + reg := callee.received.(*Registered).Registration + So(reg, ShouldNotEqual, 0) + }) + + Convey("The dealer should have the endpoint registered", func() { + reg := callee.received.(*Registered).Registration + reg2, ok := dealer.registrations[testProcedure] + So(ok, ShouldBeTrue) + So(reg, ShouldEqual, reg2) + proc, ok := dealer.procedures[reg] + So(ok, ShouldBeTrue) + So(proc.Procedure, ShouldEqual, testProcedure) + }) + + Convey("The same procedure cannot be registered more than once", func() { + msg := &Register{Request: 321, Procedure: testProcedure} + dealer.Register(callee, msg) + So(callee.received, ShouldHaveSameTypeAs, &Error{}) + }) + }) +} diff --git a/wampv2/realm.go b/wampv2/realm.go index 0bf93f6..47c6e58 100644 --- a/wampv2/realm.go +++ b/wampv2/realm.go @@ -5,6 +5,9 @@ type Realm interface { // Broker returns a custom broker for this realm. // If this is nil, the default broker will be used. Broker() Broker + // Dealer returns a custom dealer for this realm. + // If this is nil, the default dealer will be used. + Dealer() Dealer } type DefaultRealm struct { @@ -17,3 +20,7 @@ func NewDefaultRealm() *DefaultRealm { func (realm *DefaultRealm) Broker() Broker { return nil } + +func (realm *DefaultRealm) Dealer() Dealer { + return nil +} diff --git a/wampv2/router.go b/wampv2/router.go index 1975628..2f8596b 100644 --- a/wampv2/router.go +++ b/wampv2/router.go @@ -31,6 +31,7 @@ type Router interface { // DefaultRouter is a very basic WAMP router. type DefaultRouter struct { *DefaultBroker + *DefaultDealer realms map[URI]Realm clients map[URI][]Session @@ -41,6 +42,7 @@ type DefaultRouter struct { func NewDefaultRouter() *DefaultRouter { return &DefaultRouter{ DefaultBroker: NewDefaultBroker(), + DefaultDealer: NewDefaultDealer(), realms: make(map[URI]Realm), clients: make(map[URI][]Session), } @@ -71,6 +73,13 @@ func (r *DefaultRouter) broker(realm URI) Broker { return r } +func (r *DefaultRouter) dealer(realm URI) Dealer { + if d := r.realms[realm].Dealer(); d != nil { + return d + } + return r +} + func (r *DefaultRouter) handleSession(sess Session, realm URI) { defer sess.Close() @@ -100,34 +109,45 @@ func (r *DefaultRouter) handleSession(sess Session, realm URI) { if pub, ok := sess.Endpoint.(Publisher); ok { r.broker(realm).Publish(pub, v) } else { - err := &Error{ - Type: v.MessageType(), - Request: v.Request, - Error: WAMP_ERROR_NOT_AUTHORIZED, - } - sess.Send(err) + r.invalidSessionError(sess, v, v.Request) } case *Subscribe: if sub, ok := sess.Endpoint.(Subscriber); ok { r.broker(realm).Subscribe(sub, v) } else { - err := &Error{ - Type: v.MessageType(), - Request: v.Request, - Error: WAMP_ERROR_NOT_AUTHORIZED, - } - sess.Send(err) + r.invalidSessionError(sess, v, v.Request) } case *Unsubscribe: if sub, ok := sess.Endpoint.(Subscriber); ok { r.broker(realm).Unsubscribe(sub, v) } else { - err := &Error{ - Type: v.MessageType(), - Request: v.Request, - Error: WAMP_ERROR_NOT_AUTHORIZED, - } - sess.Send(err) + r.invalidSessionError(sess, v, v.Request) + } + + // Dealer messages + case *Register: + if callee, ok := sess.Endpoint.(Callee); ok { + r.dealer(realm).Register(callee, v) + } else { + r.invalidSessionError(sess, v, v.Request) + } + case *Unregister: + if callee, ok := sess.Endpoint.(Callee); ok { + r.dealer(realm).Unregister(callee, v) + } else { + r.invalidSessionError(sess, v, v.Request) + } + case *Call: + if caller, ok := sess.Endpoint.(Caller); ok { + r.dealer(realm).Call(caller, v) + } else { + r.invalidSessionError(sess, v, v.Request) + } + case *Yield: + if callee, ok := sess.Endpoint.(Callee); ok { + r.dealer(realm).Yield(callee, v) + } else { + r.invalidSessionError(sess, v, v.Request) } default: @@ -136,6 +156,14 @@ func (r *DefaultRouter) handleSession(sess Session, realm URI) { } } +func (r *DefaultRouter) invalidSessionError(sess Session, msg Message, req ID) { + sess.Send(&Error{ + Type: msg.MessageType(), + Request: req, + Error: WAMP_ERROR_NOT_AUTHORIZED, + }) +} + func (r *DefaultRouter) Accept(ep Endpoint) error { if r.closing { ep.Send(&Abort{Reason: WAMP_ERROR_SYSTEM_SHUTDOWN}) diff --git a/wampv2/router_test.go b/wampv2/router_test.go index 138d5e1..4adae12 100644 --- a/wampv2/router_test.go +++ b/wampv2/router_test.go @@ -3,7 +3,7 @@ package wampv2 import "testing" import "time" -const test_realm = URI("test.realm") +const testRealm = URI("test.realm") type basicEndpoint struct { *localEndpoint @@ -20,25 +20,23 @@ func (ep *basicEndpoint) Send(msg Message) error { } // satisfy ErrorHandler -func (ep *basicEndpoint) SendError(msg *Error) { - ep.Send(msg) -} +func (ep *basicEndpoint) SendError(msg *Error) { ep.Send(msg) } // satisfy Publisher -func (ep *basicEndpoint) SendPublished(msg *Published) { - ep.Send(msg) -} +func (ep *basicEndpoint) SendPublished(msg *Published) { ep.Send(msg) } // satisfy Subscriber -func (ep *basicEndpoint) SendEvent(msg *Event) { - ep.Send(msg) -} -func (ep *basicEndpoint) SendUnsubscribed(msg *Unsubscribed) { - ep.Send(msg) -} -func (ep *basicEndpoint) SendSubscribed(msg *Subscribed) { - ep.Send(msg) -} +func (ep *basicEndpoint) SendEvent(msg *Event) { ep.Send(msg) } +func (ep *basicEndpoint) SendUnsubscribed(msg *Unsubscribed) { ep.Send(msg) } +func (ep *basicEndpoint) SendSubscribed(msg *Subscribed) { ep.Send(msg) } + +// satisfy Callee +func (ep *basicEndpoint) SendRegistered(msg *Registered) { ep.Send(msg) } +func (ep *basicEndpoint) SendUnregistered(msg *Unregistered) { ep.Send(msg) } +func (ep *basicEndpoint) SendInvocation(msg *Invocation) { ep.Send(msg) } + +// satisfy Caller +func (ep *basicEndpoint) SendResult(msg *Result) { ep.Send(msg) } func (ep *basicEndpoint) Close() error { close(ep.outgoing) @@ -47,9 +45,9 @@ func (ep *basicEndpoint) Close() error { func basicConnect(t *testing.T, ep *basicEndpoint, server Endpoint) *DefaultRouter { r := NewDefaultRouter() - r.RegisterRealm(test_realm, NewDefaultRealm()) + r.RegisterRealm(testRealm, NewDefaultRealm()) - ep.Send(&Hello{Realm: test_realm}) + ep.Send(&Hello{Realm: testRealm}) if err := r.Accept(server); err != nil { t.Fatal(err) } @@ -59,7 +57,7 @@ func basicConnect(t *testing.T, ep *basicEndpoint, server Endpoint) *DefaultRout } if msg := <-ep.incoming; msg.MessageType() != WELCOME { - t.Fatal("Expected first message sent to be a wescome message") + t.Fatal("Expected first message sent to be a welcome message") } return r } @@ -159,7 +157,7 @@ func TestPublishAcknowledge(t *testing.T) { } func TestSubscribe(t *testing.T) { - const test_topic = URI("some.uri") + const testTopic = URI("some.uri") subClient, subServer := pipe() sub := &basicEndpoint{subClient} @@ -167,7 +165,7 @@ func TestSubscribe(t *testing.T) { defer r.Close() subscribeId := NewID() - sub.Send(&Subscribe{Request: subscribeId, Topic: test_topic}) + sub.Send(&Subscribe{Request: subscribeId, Topic: testTopic}) var subscriptionId ID select { @@ -185,12 +183,12 @@ func TestSubscribe(t *testing.T) { pubClient, pubServer := pipe() pub := &basicEndpoint{pubClient} - pub.Send(&Hello{Realm: test_realm}) + pub.Send(&Hello{Realm: testRealm}) if err := r.Accept(&basicEndpoint{pubServer}); err != nil { t.Fatal("Error pubing publisher") } - pub_id := NewID() - pub.Send(&Publish{Request: pub_id, Topic: test_topic}) + pubId := NewID() + pub.Send(&Publish{Request: pubId, Topic: testTopic}) select { case <-time.After(time.Millisecond): @@ -204,3 +202,72 @@ func TestSubscribe(t *testing.T) { // TODO: check Details, Arguments, ArgumentsKw } } + +type basicCallee struct{} + +func TestCall(t *testing.T) { + const testProcedure = URI("turnpike.test.endpoint") + calleeClient, calleeServer := pipe() + callee := &basicEndpoint{calleeClient} + r := basicConnect(t, callee, &basicEndpoint{calleeServer}) + defer r.Close() + + registerId := NewID() + // callee registers remote procedure + callee.Send(&Register{Request: registerId, Procedure: testProcedure}) + + var registrationId ID + select { + case <-time.After(time.Millisecond): + t.Fatal("Timed out waiting for REGISTERED") + case msg := <-callee.incoming: + if registered, ok := msg.(*Registered); !ok { + t.Fatalf("Expected REGISTERED, but received %s instead: %+v", msg.MessageType(), msg) + } else if registered.Request != registerId { + t.Fatalf("Request id does not match the one sent: %d != %d", registered.Request, registerId) + } else { + registrationId = registered.Registration + } + } + + callerClient, callerServer := pipe() + caller := &basicEndpoint{callerClient} + caller.Send(&Hello{Realm: testRealm}) + if err := r.Accept(&basicEndpoint{callerServer}); err != nil { + t.Fatal("Error connecting caller") + } + if msg := <-caller.incoming; msg.MessageType() != WELCOME { + t.Fatal("Expected first message sent to be a welcome message") + } + callId := NewID() + // caller calls remote procedure + caller.Send(&Call{Request: callId, Procedure: testProcedure}) + + var invocationId ID + select { + case <-time.After(time.Millisecond): + t.Fatal("Timed out waiting for INVOCATION") + case msg := <-callee.incoming: + if invocation, ok := msg.(*Invocation); !ok { + t.Errorf("Expected INVOCATION, but received %s instead: %+v", msg.MessageType(), msg) + } else if invocation.Registration != registrationId { + t.Errorf("Registration id does not match the one assigned: %d != %d", invocation.Registration, registrationId) + } else { + invocationId = invocation.Request + } + } + + // callee returns result of remove procedure + callee.Send(&Yield{Request: invocationId}) + + select { + case <-time.After(time.Millisecond): + t.Fatal("Timed out waiting for RESULT") + case msg := <-caller.incoming: + if result, ok := msg.(*Result); !ok { + t.Errorf("Expected RESULT, but received %s instead: %+v", msg.MessageType(), msg) + } else if result.Request != callId { + t.Errorf("Result id does not match the call id: %d != %d", result.Request, callId) + } + } +} diff --git a/wampv2/util.go b/wampv2/util.go index c2e8dc6..8618730 100644 --- a/wampv2/util.go +++ b/wampv2/util.go @@ -19,7 +19,7 @@ const ( WAMP_ERROR_GOODBYE_AND_OUT = URI("wamp.error.goodbye_and_out") // A Dealer could not perform a call, since the procedure called does not exist. - WAMP_NO_SUCH_PROCEDURE = URI("wamp.error.no_such_procedure") + WAMP_ERROR_NO_SUCH_PROCEDURE = URI("wamp.error.no_such_procedure") // A Broker could not perform a unsubscribe, since the given subscription is not active. WAMP_ERROR_NO_SUCH_SUBSCRIPTION = URI("wamp.error.no_such_subscription") @@ -34,7 +34,7 @@ const ( WAMP_ERROR_INVALID_TOPIC = URI("wamp.error.invalid_topic") // A procedure could not be registered, since a procedure with the given URI is already registered (and the Dealer is not able to set up a distributed registration). - WAMP_ERROR_ERROR_PROCEDURE_ALREADY_EXISTS = URI("wamp.error.procedure_already_exists") + WAMP_ERROR_PROCEDURE_ALREADY_EXISTS = URI("wamp.error.procedure_already_exists") ) const ( diff --git a/wampv2/websocket_test.go b/wampv2/websocket_test.go index 6be8073..4a21a03 100644 --- a/wampv2/websocket_test.go +++ b/wampv2/websocket_test.go @@ -12,7 +12,7 @@ import ( func newWebsocketServer(t *testing.T) (int, Router, io.Closer) { r := NewDefaultRouter() - r.RegisterRealm(test_realm, NewDefaultRealm()) + r.RegisterRealm(testRealm, NewDefaultRealm()) s := NewWebsocketServer(r) s.RegisterProtocol(jsonWebsocketProtocol, websocket.TextMessage, new(JSONSerializer)) s.RegisterProtocol(msgpackWebsocketProtocol, websocket.BinaryMessage, new(MessagePackSerializer)) @@ -38,7 +38,7 @@ func TestWSHandshakeJSON(t *testing.T) { t.Fatal(err) } - ep.Send(&Hello{Realm: test_realm}) + ep.Send(&Hello{Realm: testRealm}) go r.Accept(ep) if msg, ok := <-ep.Receive(); !ok { @@ -57,7 +57,7 @@ func TestWSHandshakeMsgpack(t *testing.T) { t.Fatal(err) } - ep.Send(&Hello{Realm: test_realm}) + ep.Send(&Hello{Realm: testRealm}) go r.Accept(ep) if msg, ok := <-ep.Receive(); !ok {