diff --git a/client.go b/client.go index b8954cd4..83354e71 100644 --- a/client.go +++ b/client.go @@ -869,7 +869,7 @@ func (c *Client) ActivateSession(ctx context.Context, s *Session) error { }, ClientSoftwareCertificates: nil, LocaleIDs: s.cfg.LocaleIDs, - UserIdentityToken: ua.NewExtensionObject(s.cfg.UserIdentityToken), + UserIdentityToken: ua.NewExtensionObject(s.cfg.UserIdentityToken, extensionObjectTypeID(s.cfg.UserIdentityToken)), UserTokenSignature: s.cfg.UserTokenSignature, } return c.SecureChannel().SendRequest(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error { @@ -895,6 +895,23 @@ func (c *Client) ActivateSession(ctx context.Context, s *Session) error { }) } +func extensionObjectTypeID(v interface{}) *ua.ExpandedNodeID { + switch v.(type) { + case *ua.AnonymousIdentityToken: + return ua.NewFourByteExpandedNodeID(0, id.AnonymousIdentityToken_Encoding_DefaultBinary) + case *ua.UserNameIdentityToken: + return ua.NewFourByteExpandedNodeID(0, id.UserNameIdentityToken_Encoding_DefaultBinary) + case *ua.X509IdentityToken: + return ua.NewFourByteExpandedNodeID(0, id.X509IdentityToken_Encoding_DefaultBinary) + case *ua.IssuedIdentityToken: + return ua.NewFourByteExpandedNodeID(0, id.IssuedIdentityToken_Encoding_DefaultBinary) + case *ua.ServerStatusDataType: + return ua.NewFourByteExpandedNodeID(0, id.ServerStatusDataType_Encoding_DefaultBinary) + default: + return ua.NewTwoByteExpandedNodeID(0) + } +} + // CloseSession closes the current session. // // See Part 4, 5.6.4 diff --git a/ua/activate_session_request_test.go b/ua/activate_session_request_test.go index 2a482d93..a81dfac8 100644 --- a/ua/activate_session_request_test.go +++ b/ua/activate_session_request_test.go @@ -6,6 +6,8 @@ package ua import ( "testing" + + "github.com/gopcua/opcua/id" ) func TestActivateSessionRequest(t *testing.T) { @@ -17,7 +19,7 @@ func TestActivateSessionRequest(t *testing.T) { ClientSignature: &SignatureData{}, ClientSoftwareCertificates: nil, LocaleIDs: nil, - UserIdentityToken: NewExtensionObject(&AnonymousIdentityToken{PolicyID: "anonymous"}), + UserIdentityToken: NewExtensionObject(&AnonymousIdentityToken{PolicyID: "anonymous"}, NewFourByteExpandedNodeID(0, id.AnonymousIdentityToken_Encoding_DefaultBinary)), UserTokenSignature: &SignatureData{}, }, Bytes: flatten( diff --git a/ua/cancel_request_test.go b/ua/cancel_request_test.go index fb0b6929..fbfcbb1c 100644 --- a/ua/cancel_request_test.go +++ b/ua/cancel_request_test.go @@ -21,7 +21,7 @@ func TestCancelRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, RequestHandle: 1, }, diff --git a/ua/cancel_response_test.go b/ua/cancel_response_test.go index 711f47f4..f74149ef 100644 --- a/ua/cancel_response_test.go +++ b/ua/cancel_response_test.go @@ -19,7 +19,7 @@ func TestCancelResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, CancelCount: 1, }, diff --git a/ua/close_secure_channel_request_test.go b/ua/close_secure_channel_request_test.go index 7732aa0c..063c5d64 100644 --- a/ua/close_secure_channel_request_test.go +++ b/ua/close_secure_channel_request_test.go @@ -21,7 +21,7 @@ func TestCloseSecureChannelRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, }, Bytes: []byte{ diff --git a/ua/close_secure_channel_response_test.go b/ua/close_secure_channel_response_test.go index 8ccd2e10..eb802ccf 100644 --- a/ua/close_secure_channel_response_test.go +++ b/ua/close_secure_channel_response_test.go @@ -19,7 +19,7 @@ func TestCloseSecureChannelResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, }, Bytes: []byte{ diff --git a/ua/close_session_request_test.go b/ua/close_session_request_test.go index a20cd6db..f9bc8971 100644 --- a/ua/close_session_request_test.go +++ b/ua/close_session_request_test.go @@ -21,7 +21,7 @@ func TestCloseSessionRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, DeleteSubscriptions: true, }, diff --git a/ua/close_session_response_test.go b/ua/close_session_response_test.go index 85ce7813..298ae6a8 100644 --- a/ua/close_session_response_test.go +++ b/ua/close_session_response_test.go @@ -19,7 +19,7 @@ func TestCloseSessionResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, }, Bytes: []byte{ diff --git a/ua/codec_test.go b/ua/codec_test.go index 24fe252a..1174b66b 100644 --- a/ua/codec_test.go +++ b/ua/codec_test.go @@ -21,8 +21,7 @@ type CodecTestCase struct { Bytes []byte } -// RunCodecTest tests encoding, decoding and length calclulation for the given -// object. +// RunCodecTest tests encoding, decoding and length calculation for the given object. func RunCodecTest(t *testing.T, cases []CodecTestCase) { t.Helper() diff --git a/ua/create_session_request_test.go b/ua/create_session_request_test.go index ae13dd1a..17da8411 100644 --- a/ua/create_session_request_test.go +++ b/ua/create_session_request_test.go @@ -21,7 +21,7 @@ func TestCreateSessionRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, ClientDescription: &ApplicationDescription{ ApplicationURI: "app-uri", diff --git a/ua/create_session_response_test.go b/ua/create_session_response_test.go index 4e1c361b..29b5839d 100644 --- a/ua/create_session_response_test.go +++ b/ua/create_session_response_test.go @@ -19,7 +19,7 @@ func TestCreateSessionResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, SessionID: NewNumericNodeID(0, 1), AuthenticationToken: NewByteStringNodeID(0, []byte{ diff --git a/ua/create_subscription_response_test.go b/ua/create_subscription_response_test.go index f980a8eb..f2386301 100644 --- a/ua/create_subscription_response_test.go +++ b/ua/create_subscription_response_test.go @@ -19,7 +19,7 @@ func TestCreateSubscriptionResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, SubscriptionID: 1, RevisedPublishingInterval: 1000, diff --git a/ua/decode.go b/ua/decode.go index 0a679d5d..2414bc0c 100644 --- a/ua/decode.go +++ b/ua/decode.go @@ -43,6 +43,7 @@ func decode(b []byte, val reflect.Value, name string) (n int, err error) { }() } + // fmt.Printf("decode: %s is a %s\n", name, val.Kind()) buf := NewBuffer(b) switch { case isBinaryDecoder(val): @@ -51,7 +52,6 @@ func decode(b []byte, val reflect.Value, name string) (n int, err error) { case isTime(val): val.Set(reflect.ValueOf(buf.ReadTime()).Convert(val.Type())) default: - // fmt.Printf("decode: %s is a %s\n", name, val.Kind()) switch val.Kind() { case reflect.Bool: val.SetBool(buf.ReadBool()) diff --git a/ua/extension_object.go b/ua/extension_object.go index 4ee7dafd..6151c08c 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -5,21 +5,33 @@ package ua import ( + "fmt" + "reflect" + "github.com/gopcua/opcua/debug" - "github.com/gopcua/opcua/id" ) // eotypes contains all known extension objects. -var eotypes = NewTypeRegistry() +var eotypes = NewFuncRegistry() // RegisterExtensionObject registers a new extension object type. // It panics if the type or the id is already registered. func RegisterExtensionObject(typeID *NodeID, v interface{}) { - if err := eotypes.Register(typeID, v); err != nil { + RegisterExtensionObjectFunc(typeID, DefaultEncodeExtensionObject, DefaultDecodeExtensionObject(v)) +} + +// RegisterExtensionObjectFunc registers a new extension object type using encode and decode functions +// It panics if the type or the id is already registered. +func RegisterExtensionObjectFunc(typeID *NodeID, ef encodefunc, df decodefunc) { + if err := eotypes.Register(typeID, ef, df); err != nil { panic("Extension object " + err.Error()) } } +func Deregister(typeID *NodeID) { + eotypes.Deregister(typeID) +} + // These flags define the value type of an ExtensionObject. // They cannot be combined. const ( @@ -28,6 +40,9 @@ const ( ExtensionObjectXML = 2 ) +type encodefunc func(v interface{}) ([]byte, error) +type decodefunc func(b []byte, v interface{}) error + // ExtensionObject is encoded as sequence of bytes prefixed by the NodeId of its DataTypeEncoding // and the number of bytes encoded. // @@ -38,15 +53,16 @@ type ExtensionObject struct { Value interface{} } -func NewExtensionObject(value interface{}) *ExtensionObject { +func NewExtensionObject(value interface{}, typeID *ExpandedNodeID) *ExtensionObject { e := &ExtensionObject{ - TypeID: ExtensionObjectTypeID(value), + TypeID: typeID, Value: value, } e.UpdateMask() return e } +// Decode fails if there is no decode func registered for e func (e *ExtensionObject) Decode(b []byte) (int, error) { buf := NewBuffer(b) e.TypeID = new(ExpandedNodeID) @@ -74,21 +90,57 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { } typeID := e.TypeID.NodeID - e.Value = eotypes.New(typeID) - if e.Value == nil { + decode := eotypes.DecodeFunc(typeID) + if decode == nil { debug.Printf("ua: unknown extension object %s", typeID) return buf.Pos(), buf.Error() } - - body.ReadStruct(e.Value) + err := decode(body.Bytes(), &e.Value) + if err != nil { + // TODO: we are losing Pos by creating new buf in decode? + return buf.Pos(), err + } + // TODO: we are losing Pos by creating new buf in decode? return buf.Pos(), body.Error() } +// Encode falls back to defaultencode if there is no encode func registered for e func (e *ExtensionObject) Encode() ([]byte, error) { - buf := NewBuffer(nil) if e == nil { e = &ExtensionObject{TypeID: NewTwoByteExpandedNodeID(0), EncodingMask: ExtensionObjectEmpty} } + + typeID := e.TypeID.NodeID + encode := eotypes.EncodeFunc(typeID) + if encode == nil { + debug.Printf("ua: unknown extension object %s", typeID) + return DefaultEncodeExtensionObject(e) + } + return encode(e) +} + +// DefaultDecode creates a new instance of v and decodes into it +func DefaultDecodeExtensionObject(v interface{}) func([]byte, interface{}) error { + return func(b []byte, vv interface{}) error { + rv := reflect.ValueOf(vv) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("incorrect type to decode into") + } + r := reflect.New(reflect.TypeOf(v).Elem()).Interface() + buf := NewBuffer(b) + buf.ReadStruct(r) + reflect.Indirect(rv).Set(reflect.ValueOf(r)) + return nil + } +} + +// DefaultEncode encodes into bytes based on the go struct +func DefaultEncodeExtensionObject(v interface{}) ([]byte, error) { + e, ok := v.(*ExtensionObject) + if !ok { + return nil, fmt.Errorf("expected ExtensionObject") + } + buf := NewBuffer(nil) buf.WriteStruct(e.TypeID) buf.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { @@ -114,23 +166,3 @@ func (e *ExtensionObject) UpdateMask() { e.EncodingMask = ExtensionObjectBinary } } - -func ExtensionObjectTypeID(v interface{}) *ExpandedNodeID { - switch v.(type) { - case *AnonymousIdentityToken: - return NewFourByteExpandedNodeID(0, id.AnonymousIdentityToken_Encoding_DefaultBinary) - case *UserNameIdentityToken: - return NewFourByteExpandedNodeID(0, id.UserNameIdentityToken_Encoding_DefaultBinary) - case *X509IdentityToken: - return NewFourByteExpandedNodeID(0, id.X509IdentityToken_Encoding_DefaultBinary) - case *IssuedIdentityToken: - return NewFourByteExpandedNodeID(0, id.IssuedIdentityToken_Encoding_DefaultBinary) - case *ServerStatusDataType: - return NewFourByteExpandedNodeID(0, id.ServerStatusDataType_Encoding_DefaultBinary) - default: - if id := eotypes.Lookup(v); id != nil { - return &ExpandedNodeID{NodeID: id} - } - return NewTwoByteExpandedNodeID(0) - } -} diff --git a/ua/extension_object_test.go b/ua/extension_object_test.go index 6b6a39e9..15ed1f45 100644 --- a/ua/extension_object_test.go +++ b/ua/extension_object_test.go @@ -6,13 +6,15 @@ package ua import ( "testing" + + "github.com/gopcua/opcua/id" ) func TestExtensionObject(t *testing.T) { cases := []CodecTestCase{ { Name: "anonymous-user-identity-token", - Struct: NewExtensionObject(&AnonymousIdentityToken{PolicyID: "anonymous"}), + Struct: NewExtensionObject(&AnonymousIdentityToken{PolicyID: "anonymous"}, NewFourByteExpandedNodeID(0, id.AnonymousIdentityToken_Encoding_DefaultBinary)), Bytes: []byte{ // TypeID 0x01, 0x00, 0x41, 0x01, diff --git a/ua/find_servers_on_network_request_test.go b/ua/find_servers_on_network_request_test.go index 5e9a89f7..ba7630cb 100644 --- a/ua/find_servers_on_network_request_test.go +++ b/ua/find_servers_on_network_request_test.go @@ -21,7 +21,7 @@ func TestFindServersOnNetworkRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, StartingRecordID: 1000, MaxRecordsToReturn: 0, @@ -34,7 +34,7 @@ func TestFindServersOnNetworkRequest(t *testing.T) { // 0xa6, 0x43, 0xf8, 0x77, 0x7b, 0xc6, 0x2f, 0xc8, // }), // time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), - // 1, 0, 0, "", NewExtensionObject(nil), + // 1, 0, 0, "", NewExtensionObject(nil,nil), // ), // 1000, // 0, diff --git a/ua/find_servers_on_network_response_test.go b/ua/find_servers_on_network_response_test.go index eb239a0e..39ec75e1 100644 --- a/ua/find_servers_on_network_response_test.go +++ b/ua/find_servers_on_network_response_test.go @@ -19,7 +19,7 @@ func TestFindServersOnNetworkResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, LastCounterResetTime: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), Servers: []*ServerOnNetwork{ @@ -72,7 +72,7 @@ func TestFindServersOnNetworkResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, LastCounterResetTime: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), Servers: []*ServerOnNetwork{ diff --git a/ua/find_servers_request_test.go b/ua/find_servers_request_test.go index 53892968..9ef530ed 100644 --- a/ua/find_servers_request_test.go +++ b/ua/find_servers_request_test.go @@ -21,7 +21,7 @@ func TestFindServersRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", }, diff --git a/ua/find_servers_response_test.go b/ua/find_servers_response_test.go index a2ca6eea..96b14f00 100644 --- a/ua/find_servers_response_test.go +++ b/ua/find_servers_response_test.go @@ -19,7 +19,7 @@ func TestFindServersResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, Servers: []*ApplicationDescription{ { diff --git a/ua/get_endpoints_request_test.go b/ua/get_endpoints_request_test.go index 60a8f8c2..4228bc63 100644 --- a/ua/get_endpoints_request_test.go +++ b/ua/get_endpoints_request_test.go @@ -21,7 +21,7 @@ func TestGetEndpointsRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", }, diff --git a/ua/get_endpoints_response_test.go b/ua/get_endpoints_response_test.go index 24ee35c5..a21818c5 100644 --- a/ua/get_endpoints_response_test.go +++ b/ua/get_endpoints_response_test.go @@ -19,7 +19,7 @@ func TestGetEndpointsResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, Endpoints: []*EndpointDescription{ { diff --git a/ua/open_secure_channel_request_test.go b/ua/open_secure_channel_request_test.go index 8d1f01fb..bf61efdd 100644 --- a/ua/open_secure_channel_request_test.go +++ b/ua/open_secure_channel_request_test.go @@ -21,7 +21,7 @@ func TestOpenSecureChannelRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, ClientProtocolVersion: 0, RequestType: SecurityTokenRequestTypeIssue, diff --git a/ua/open_secure_channel_response_test.go b/ua/open_secure_channel_response_test.go index 4cc10748..d2ccf8b2 100644 --- a/ua/open_secure_channel_response_test.go +++ b/ua/open_secure_channel_response_test.go @@ -19,7 +19,7 @@ func TestOpenSecureChannelResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, ServerProtocolVersion: 0, SecurityToken: &ChannelSecurityToken{ diff --git a/ua/read_request_test.go b/ua/read_request_test.go index d5adf6d2..bf1b6eaa 100644 --- a/ua/read_request_test.go +++ b/ua/read_request_test.go @@ -21,7 +21,7 @@ func TestReadRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, MaxAge: 0, TimestampsToReturn: TimestampsToReturnBoth, diff --git a/ua/read_response_test.go b/ua/read_response_test.go index 5b3f5d85..ac9e4768 100644 --- a/ua/read_response_test.go +++ b/ua/read_response_test.go @@ -19,7 +19,7 @@ func TestReadResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, Results: []*DataValue{ { diff --git a/ua/request_header_test.go b/ua/request_header_test.go index ccd1a43d..223bd426 100644 --- a/ua/request_header_test.go +++ b/ua/request_header_test.go @@ -13,7 +13,7 @@ func NewNullRequestHeader() *RequestHeader { return &RequestHeader{ AuthenticationToken: NewTwoByteNodeID(0), Timestamp: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), } } diff --git a/ua/response_header_test.go b/ua/response_header_test.go index f59fdd76..e84b7e3a 100644 --- a/ua/response_header_test.go +++ b/ua/response_header_test.go @@ -13,7 +13,7 @@ func NewNullResponseHeader() *ResponseHeader { return &ResponseHeader{ Timestamp: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), ServiceDiagnostics: &DiagnosticInfo{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), } } diff --git a/ua/service.go b/ua/service.go index 01236d10..44d6c8e7 100644 --- a/ua/service.go +++ b/ua/service.go @@ -6,30 +6,51 @@ package ua import ( "fmt" + "reflect" "github.com/gopcua/opcua/debug" ) // svcreg contains all known service request/response objects. -var svcreg = NewTypeRegistry() +var svcreg = NewFuncRegistry() +var svctypeids = map[reflect.Type]uint16{} // RegisterService registers a new service object type. // It panics if the type or the id is already registered. func RegisterService(typeID uint16, v interface{}) { - if err := svcreg.Register(NewFourByteNodeID(0, typeID), v); err != nil { + ef := func(vv interface{}) ([]byte, error) { + // TODO check/ensure vv is of type v ? + // TODO + return Encode(vv) + } + df := func(b []byte, vv interface{}) error { + rv := reflect.ValueOf(vv) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("incorrect type to decode into") + } + r := reflect.New(reflect.TypeOf(v).Elem()).Interface() + buf := NewBuffer(b) + buf.ReadStruct(r) + reflect.Indirect(rv).Set(reflect.ValueOf(r)) + return nil + } + nodeID := NewFourByteExpandedNodeID(0, typeID).NodeID + if err := svcreg.Register(nodeID, ef, df); err != nil { panic("Service " + err.Error()) } + typ := reflect.TypeOf(v) + svctypeids[typ] = uint16(nodeID.IntID()) } // ServiceTypeID returns the id of the service object type as // registered with RegisterService. If the service object is not // known the function returns 0. func ServiceTypeID(v interface{}) uint16 { - id := svcreg.Lookup(v) - if id == nil { + id, ok := svctypeids[reflect.TypeOf(v)] + if !ok { return 0 } - return uint16(id.IntID()) + return id } func DecodeService(b []byte) (*ExpandedNodeID, interface{}, error) { @@ -40,15 +61,16 @@ func DecodeService(b []byte) (*ExpandedNodeID, interface{}, error) { } b = b[n:] - v := svcreg.New(typeID.NodeID) - if v == nil { + decode := svcreg.DecodeFunc(typeID.NodeID) + if decode == nil { return nil, nil, StatusBadServiceUnsupported } if debug.FlagSet("packet") { - fmt.Printf("%T: %#v\n", v, b) + fmt.Printf("%T: %#v\n", decode, b) } - _, err = Decode(b, v) + var v interface{} + err = decode(b, &v) return typeID, v, err } diff --git a/ua/typereg.go b/ua/typereg.go index 0cf1242f..9edd34f1 100644 --- a/ua/typereg.go +++ b/ua/typereg.go @@ -5,92 +5,101 @@ package ua import ( - "reflect" "sync" "github.com/gopcua/opcua/errors" ) -// TypeRegistry provides a registry for Go types. -// -// Each type is registered with a unique identifier -// which cannot be changed for the lifetime of the component. -// -// Types can be registered multiple times under different -// identifiers. -// -// The implementation is safe for concurrent use. -type TypeRegistry struct { - mu sync.Mutex - types map[string]reflect.Type - ids map[reflect.Type]string +type FuncRegistry struct { + mu sync.Mutex + encodeFuncs map[string]encodefunc + decodeFuncs map[string]decodefunc } -// NewTypeRegistry returns a new type registry. -func NewTypeRegistry() *TypeRegistry { - return &TypeRegistry{ - types: make(map[string]reflect.Type), - ids: make(map[reflect.Type]string), +// NewFuncRegistry returns a new func registry. +func NewFuncRegistry() *FuncRegistry { + return &FuncRegistry{ + encodeFuncs: make(map[string]encodefunc), + decodeFuncs: make(map[string]decodefunc), } } -// New returns a new instance of the type with the given id. +// EncodeFunc returns the function registered to encode Node with ID id // // If the id is not known the function returns nil. // // New panics if id is nil. -func (r *TypeRegistry) New(id *NodeID) interface{} { +func (r *FuncRegistry) EncodeFunc(id *NodeID) encodefunc { if id == nil { - panic("opcua: missing id in call to TypeRegistry.New") + panic("opcua: missing id in call to FuncRegistry.New") } r.mu.Lock() defer r.mu.Unlock() - typ, ok := r.types[id.String()] + f, ok := r.encodeFuncs[id.String()] if !ok { return nil } - return reflect.New(typ.Elem()).Interface() + return f } -// Lookup returns the id of the type of v or nil if -// the type is not registered. +// DecodeFunc returns the function registered to decode Node with ID id // -// If the type was registered multiple times the first -// registered id for this type is returned. -func (r *TypeRegistry) Lookup(v interface{}) *NodeID { +// If the id is not known the function returns nil. +// +// New panics if id is nil. +func (r *FuncRegistry) DecodeFunc(id *NodeID) decodefunc { + if id == nil { + panic("opcua: missing id in call to FuncRegistry.New") + } + r.mu.Lock() defer r.mu.Unlock() - if id, ok := r.ids[reflect.TypeOf(v)]; ok { - return MustParseNodeID(id) + + f, ok := r.decodeFuncs[id.String()] + if !ok { + return nil } - return nil + return f } -// Register adds a new type to the registry. +// Register adds a new node to the registry. // -// If the id is already registered as a different type the function returns an error. +// If the id is already registered the function returns an error. // // Register panics if id is nil. -func (r *TypeRegistry) Register(id *NodeID, v interface{}) error { +func (r *FuncRegistry) Register(id *NodeID, ef encodefunc, df decodefunc) error { if id == nil { - panic("opcua: missing id in call to TypeRegistry.Register") + panic("opcua: missing id in call to FuncRegistry.Register") } r.mu.Lock() defer r.mu.Unlock() - typ := reflect.TypeOf(v) ids := id.String() - if cur := r.types[ids]; cur != nil && cur != typ { - return errors.Errorf("%s is already registered as %v", id, cur) + if cur := r.encodeFuncs[ids]; cur != nil { + return errors.Errorf("%s is already registered", id) } - r.types[ids] = typ + r.encodeFuncs[ids] = ef - if _, exists := r.ids[typ]; !exists { - r.ids[typ] = ids + if _, exists := r.decodeFuncs[ids]; !exists { + r.decodeFuncs[ids] = df } return nil } + +// Deregister removes a node from the registry +func (r *FuncRegistry) Deregister(id *NodeID) { + if id == nil { + panic("opcua: missing id in call to FuncRegistry.Register") + } + + r.mu.Lock() + defer r.mu.Unlock() + + ids := id.String() + delete(r.encodeFuncs, ids) + delete(r.decodeFuncs, ids) +} diff --git a/ua/variant_test.go b/ua/variant_test.go index df072d68..65ddf406 100644 --- a/ua/variant_test.go +++ b/ua/variant_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gopcua/opcua/errors" + "github.com/gopcua/opcua/id" "github.com/pascaldekloe/goe/verify" ) @@ -259,7 +260,7 @@ func TestVariant(t *testing.T) { { Name: "ExtensionObject", Struct: MustVariant(NewExtensionObject( - &AnonymousIdentityToken{PolicyID: "anonymous"}, + &AnonymousIdentityToken{PolicyID: "anonymous"}, NewFourByteExpandedNodeID(0, id.AnonymousIdentityToken_Encoding_DefaultBinary), )), Bytes: []byte{ // variant encoding mask @@ -291,7 +292,7 @@ func TestVariant(t *testing.T) { }, SecondsTillShutdown: 0, ShutdownReason: NewLocalizedText(""), - }, + }, NewFourByteExpandedNodeID(0, id.ServerStatusDataType_Encoding_DefaultBinary), )), Bytes: []byte{ // variant encoding mask diff --git a/ua/write_request_test.go b/ua/write_request_test.go index 83cecf3d..371b70c0 100644 --- a/ua/write_request_test.go +++ b/ua/write_request_test.go @@ -21,7 +21,7 @@ func TestWriteRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, NodesToWrite: []*WriteValue{ { @@ -79,7 +79,7 @@ func TestWriteRequest(t *testing.T) { }), Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, NodesToWrite: []*WriteValue{ { diff --git a/ua/write_response_test.go b/ua/write_response_test.go index d1d4b3fd..52c9eabf 100644 --- a/ua/write_response_test.go +++ b/ua/write_response_test.go @@ -19,7 +19,7 @@ func TestWriteResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, Results: []StatusCode{StatusOK}, }, @@ -51,7 +51,7 @@ func TestWriteResponse(t *testing.T) { RequestHandle: 1, ServiceDiagnostics: &DiagnosticInfo{}, StringTable: []string{}, - AdditionalHeader: NewExtensionObject(nil), + AdditionalHeader: NewExtensionObject(nil, NewTwoByteExpandedNodeID(0)), }, Results: []StatusCode{StatusOK, StatusBadUserAccessDenied}, }, diff --git a/uasc/message_test.go b/uasc/message_test.go index 1153b612..91e2b6b0 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/gopcua/opcua/id" - "github.com/gopcua/opcua/ua" ) @@ -24,7 +22,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -35,14 +33,13 @@ func TestMessage(t *testing.T) { Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), + AdditionalHeader: ua.NewExtensionObject(nil, ua.NewTwoByteExpandedNodeID(0)), }, ClientProtocolVersion: 0, RequestType: ua.SecurityTokenRequestTypeIssue, SecurityMode: ua.MessageSecurityModeNone, RequestedLifetime: 6000000, }, - id.OpenSecureChannelRequest_Encoding_DefaultBinary, s.nextRequestID(), ) @@ -123,7 +120,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -134,11 +131,10 @@ func TestMessage(t *testing.T) { Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), + AdditionalHeader: ua.NewExtensionObject(nil, ua.NewTwoByteExpandedNodeID(0)), }, EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", }, - id.GetEndpointsRequest_Encoding_DefaultBinary, s.nextRequestID(), ) @@ -194,7 +190,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -205,10 +201,9 @@ func TestMessage(t *testing.T) { Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), RequestHandle: 1, ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), + AdditionalHeader: ua.NewExtensionObject(nil, ua.NewTwoByteExpandedNodeID(0)), }, }, - id.CloseSecureChannelRequest_Encoding_DefaultBinary, s.nextRequestID(), ) diff --git a/uasc/secure_channel_instance.go b/uasc/secure_channel_instance.go index bf123eb8..cd3f8195 100644 --- a/uasc/secure_channel_instance.go +++ b/uasc/secure_channel_instance.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gopcua/opcua/errors" - "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uapolicy" ) @@ -59,8 +58,7 @@ func (c *channelInstance) nextSequenceNumber() uint32 { } func (c *channelInstance) newRequestMessage(req ua.Request, reqID uint32, authToken *ua.NodeID, timeout time.Duration) (*Message, error) { - typeID := ua.ServiceTypeID(req) - if typeID == 0 { + if ua.ServiceTypeID(req) == 0 { return nil, errors.Errorf("unknown service %T. Did you call register?", req) } if authToken == nil { @@ -80,14 +78,14 @@ func (c *channelInstance) newRequestMessage(req ua.Request, reqID uint32, authTo req.SetHeader(reqHdr) // encode the message - return c.newMessage(req, typeID, reqID), nil + return c.newMessage(req, reqID), nil } -func (c *channelInstance) newMessage(srv interface{}, typeID uint16, requestID uint32) *Message { +func (c *channelInstance) newMessage(srv interface{}, requestID uint32) *Message { sequenceNumber := c.nextSequenceNumber() - switch typeID { - case id.OpenSecureChannelRequest_Encoding_DefaultBinary, id.OpenSecureChannelResponse_Encoding_DefaultBinary: + switch srv.(type) { + case *ua.OpenSecureChannelRequest, *ua.OpenSecureChannelResponse: // Do not send the thumbprint for security mode None // even if we have a certificate. // @@ -103,18 +101,18 @@ func (c *channelInstance) newMessage(srv interface{}, typeID uint16, requestID u AsymmetricSecurityHeader: NewAsymmetricSecurityHeader(c.sc.cfg.SecurityPolicyURI, c.sc.cfg.Certificate, thumbprint), SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + TypeID: ua.NewFourByteExpandedNodeID(0, ua.ServiceTypeID(srv)), Service: srv, } - case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: + case *ua.CloseSecureChannelRequest, *ua.CloseSecureChannelResponse: return &Message{ MessageHeader: &MessageHeader{ Header: NewHeader(MessageTypeCloseSecureChannel, ChunkTypeFinal, c.secureChannelID), SymmetricSecurityHeader: NewSymmetricSecurityHeader(c.securityTokenID), SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + TypeID: ua.NewFourByteExpandedNodeID(0, ua.ServiceTypeID(srv)), Service: srv, } @@ -125,7 +123,7 @@ func (c *channelInstance) newMessage(srv interface{}, typeID uint16, requestID u SymmetricSecurityHeader: NewSymmetricSecurityHeader(c.securityTokenID), SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + TypeID: ua.NewFourByteExpandedNodeID(0, ua.ServiceTypeID(srv)), Service: srv, } } diff --git a/uatest/custom_codec_test.go b/uatest/custom_codec_test.go new file mode 100644 index 00000000..3f0aade8 --- /dev/null +++ b/uatest/custom_codec_test.go @@ -0,0 +1,142 @@ +//go:build integration +// +build integration + +package uatest + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/gopcua/opcua" + "github.com/gopcua/opcua/ua" + "github.com/pascaldekloe/goe/verify" +) + +func TestReadNodeIDWithDecodeFunc(t *testing.T) { + ctx := context.Background() + + srv := NewServer("read_unknow_node_id_server.py") + defer srv.Close() + + c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) + if err != nil { + t.Fatal(err) + } + if err := c.Connect(ctx); err != nil { + t.Fatal(err) + } + defer c.Close(ctx) + + nodeID := ua.NewStringNodeID(2, "IntValZero") + + decodeFunc := func(b []byte, v interface{}) error { + // decode into map[string]interface, which means + // decode into dynamically generated go type + // then json marshal/unmarshal :) + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("incorrect type to decode into") + } + r := &struct { + I int64 `json:"i"` + }{} // TODO generate dynamically + buf := ua.NewBuffer(b) + buf.ReadStruct(r) + out := map[string]interface{}{} + b, err := json.Marshal(r) + if err != nil { + return err + } + if err := json.Unmarshal(b, &out); err != nil { + return err + } + reflect.Indirect(rv).Set(reflect.ValueOf(out)) + return nil + } + + ua.RegisterExtensionObjectFunc(ua.NewStringNodeID(2, "IntValType"), nil, decodeFunc) + defer ua.Deregister(ua.NewStringNodeID(2, "IntValType")) + + resp, err := c.Read(ctx, &ua.ReadRequest{ + NodesToRead: []*ua.ReadValueID{ + {NodeID: nodeID}, + }, + }) + if err != nil { + t.Fatal(err) + } + + want := map[string]interface{}{"i": float64(0)} // TODO: float64? yay json! + if got := resp.Results[0].Value.Value().(*ua.ExtensionObject).Value; !reflect.DeepEqual(got, want) { + t.Errorf("got %#v want %#v for a node with an unknown type", got, want) + } +} + +type ExtraComplex struct { + ignore, i, j int64 +} + +// TestCallMethod, but instead of passing Complex{3,8} as an input argument, we pass ExtraComplex{42,3,8} +// We expect the same result only because we register the nodeID for Complex objects with a custom encode func +// Imagine ExtraComplex as a newer version of the API, and encodefunc allows for backwards compatibility +func TestCallMethodWithEncodeFunc(t *testing.T) { + complexNodeID := ua.NewStringNodeID(2, "ComplexType") + + encode := func(v interface{}) ([]byte, error) { + // map ExtraComplex -> Complex, dropping 'ignore' field + e, ok := v.(*ua.ExtensionObject) + if !ok { + return nil, fmt.Errorf("expected extensionobject") + } + // if we have ExtensionObjects for both ExtraComplex and Complex objects sharing the same nodeID, + // then this function will get called for both. Hence the if-statement + if ec, ok := e.Value.(*ExtraComplex); ok { + e.Value = &Complex{ec.i, ec.j} + } + return ua.DefaultEncodeExtensionObject(e) + } + + ua.RegisterExtensionObjectFunc(complexNodeID, encode, nil) + defer ua.Deregister(complexNodeID) + + req := &ua.CallMethodRequest{ + ObjectID: ua.NewStringNodeID(2, "main"), + MethodID: ua.NewStringNodeID(2, "sumOfSquare"), + InputArguments: []*ua.Variant{ + ua.MustVariant(ua.NewExtensionObject(&ExtraComplex{42, 3, 8}, &ua.ExpandedNodeID{NodeID: complexNodeID})), + }, + } + out := []*ua.Variant{ua.MustVariant(int64(9 + 64))} + + ctx := context.Background() + + srv := NewServer("method_server.py") + defer srv.Close() + + c, err := opcua.NewClient(srv.Endpoint, srv.Opts...) + if err != nil { + t.Fatal(err) + } + if err := c.Connect(ctx); err != nil { + t.Fatal(err) + } + defer c.Close(ctx) + + resp, err := c.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, ua.StatusOK; got != want { + t.Fatalf("got status %v want %v", got, want) + } + if got, want := resp.OutputArguments, out; !verify.Values(t, "", got, want) { + t.Fail() + } +} + +func TestReadUnregisteredExtensionObject(t *testing.T) { + // TODO ask server for description, then decode anyways? +} diff --git a/uatest/method_test.go b/uatest/method_test.go index 77172714..b10b0867 100644 --- a/uatest/method_test.go +++ b/uatest/method_test.go @@ -18,7 +18,9 @@ type Complex struct { } func TestCallMethod(t *testing.T) { - ua.RegisterExtensionObject(ua.NewStringNodeID(2, "ComplexType"), new(Complex)) + complexNodeID := ua.NewStringNodeID(2, "ComplexType") + //ua.RegisterExtensionObject(complexNodeID, new(Complex)) + //defer ua.Deregister(complexNodeID) tests := []struct { req *ua.CallMethodRequest @@ -49,7 +51,7 @@ func TestCallMethod(t *testing.T) { ObjectID: ua.NewStringNodeID(2, "main"), MethodID: ua.NewStringNodeID(2, "sumOfSquare"), InputArguments: []*ua.Variant{ - ua.MustVariant(ua.NewExtensionObject(&Complex{3, 8})), + ua.MustVariant(ua.NewExtensionObject(&Complex{3, 8}, &ua.ExpandedNodeID{NodeID: complexNodeID})), }, }, out: []*ua.Variant{ua.MustVariant(int64(9 + 64))}, diff --git a/uatest/read_unknow_node_id_test.go b/uatest/read_unknow_node_id_test.go index b5fd4249..5a63de0b 100644 --- a/uatest/read_unknow_node_id_test.go +++ b/uatest/read_unknow_node_id_test.go @@ -32,6 +32,7 @@ func TestReadUnknowNodeID(t *testing.T) { // read node with unknown extension object // This should be OK nodeWithUnknownType := ua.NewStringNodeID(2, "IntValZero") + resp, err := c.Read(ctx, &ua.ReadRequest{ NodesToRead: []*ua.ReadValueID{ {NodeID: nodeWithUnknownType},