From 0173adf5a11cdcc10222d98c6534982ce0b15aa3 Mon Sep 17 00:00:00 2001 From: AlexanderMescheryakov Date: Tue, 28 Mar 2023 22:22:43 +0400 Subject: [PATCH 1/3] CCE-M3: Hook to extract VPC Endpoint ID --- diam/message.go | 28 ++++++++++++++++++---------- diam/server.go | 22 ++++++++++++++++------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/diam/message.go b/diam/message.go index 73c2d72..96ae6fe 100644 --- a/diam/message.go +++ b/diam/message.go @@ -64,19 +64,19 @@ func readerBufferSlice(buf *bytes.Buffer, l int) []byte { // ReadMessage reads a binary stream from the reader and uses the given // dictionary to parse it. -func ReadMessage(reader io.Reader, dictionary *dict.Parser) (*Message, error) { +func ReadMessage(reader io.Reader, dictionary *dict.Parser, hook *HeaderReaderHook) (*Message, *string, error) { buf := newReaderBuffer() defer putReaderBuffer(buf) m := &Message{dictionary: dictionary} - cmd, stream, err := m.readHeader(reader, buf) + cmd, stream, endpointId, err := m.readHeader(reader, buf, hook) if err != nil { - return nil, err + return nil, endpointId, err } m.stream = stream if err = m.readBody(reader, buf, cmd, stream); err != nil { - return nil, err + return nil, endpointId, err } - return m, nil + return m, endpointId, nil } // MessageStream returns the stream #, the message was received on (when applicable) @@ -84,7 +84,14 @@ func (m *Message) MessageStream() uint { return m.stream } -func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer) (cmd *dict.Command, stream uint, err error) { +func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer, hook *HeaderReaderHook) (cmd *dict.Command, stream uint, endpointId *string, err error) { + if (hook != nil) { + endpointId, err = (*hook)(r) + if err != nil { + return nil, stream, nil, err + } + } + b := buf.Bytes()[:HeaderLength] msr, isMulti := r.(MultistreamReader) if isMulti { @@ -96,20 +103,21 @@ func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer) (cmd *dict.Command, _, err = io.ReadFull(r, b) } if err != nil { - return nil, stream, err + return nil, stream, nil, err } + m.Header, err = DecodeHeader(b) if err != nil { - return nil, stream, err + return nil, stream, endpointId, err } cmd, err = m.Dictionary().FindCommand( m.Header.ApplicationID, m.Header.CommandCode, ) if err != nil { - return nil, stream, err + return nil, stream, endpointId, err } - return cmd, stream, nil + return cmd, stream, endpointId, nil } func (m *Message) readBody(r io.Reader, buf *bytes.Buffer, cmd *dict.Command, stream uint) error { diff --git a/diam/server.go b/diam/server.go index 44514bc..dd42689 100644 --- a/diam/server.go +++ b/diam/server.go @@ -43,6 +43,7 @@ type Conn interface { Context() context.Context // Returns the internal context SetContext(ctx context.Context) // Stores a new context Connection() net.Conn // Returns network connection + ProxyHeaderVpcEndpointId() *string // Returns the VPC Endpoint ID fetched by the header hook } // The CloseNotifier interface is implemented by Conns which @@ -87,9 +88,10 @@ type conn struct { tlsState *tls.ConnectionState // or nil when not using TLS writer *response // the diam.Conn exposed to handlers - mu sync.Mutex // guards the following - closeNotifyc chan struct{} - clientGone bool + mu sync.Mutex // guards the following + closeNotifyc chan struct{} + clientGone bool + vpcEndpointId *string // the AWS VPC endpoint ID extracted from the proxy header } func (c *conn) closeNotify() <-chan struct{} { @@ -165,13 +167,14 @@ func (c *conn) readMessage() (m *Message, err error) { if msc, isMulti := c.rwc.(MultistreamConn); isMulti { // If it's a multi-stream association - reset the stream to "undefined" prior to reading next message msc.ResetCurrentStream() - m, err = ReadMessage(msc, c.dictionary()) // MultistreamConn has it's own buffering + m, vpcEndpointId, err = ReadMessage(msc, c.dictionary(), c.server.HeaderHook) // MultistreamConn has it's own buffering } else { - m, err = ReadMessage(c.buf.Reader, c.dictionary()) + m, vpcEndpointId, err = ReadMessage(c.buf.Reader, c.dictionary(), c.server.HeaderHook) } if err != nil { return nil, err } + c.vpcEndpointId = vpcEndpointId return m, nil } @@ -338,6 +341,10 @@ func (w *response) Connection() net.Conn { return w.conn.rwc } +func (w *response) ProxyHeaderVpcEndpointId() *string { + return w.conn.vpcEndpointId +} + // The HandlerFunc type is an adapter to allow the use of // ordinary functions as diameter handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a @@ -557,6 +564,8 @@ func Serve(l net.Listener, handler Handler) error { return srv.Serve(l) } +type HeaderReaderHook = func(io.Reader) (*string, error) + // A Server defines parameters for running a diameter server. type Server struct { Network string // network of the address - empty string defaults to tcp @@ -567,7 +576,8 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to -} + HeaderHook +}/ // serverHandler delegates to either the server's Handler or DefaultServeMux. type serverHandler struct { From 58f01c4ed4dd97e47b1d8fe93fb167792f4a4a5b Mon Sep 17 00:00:00 2001 From: Ahmed Anas Date: Wed, 29 Mar 2023 08:50:30 +0000 Subject: [PATCH 2/3] change hook approach --- diam/message.go | 32 +++++++++++++++----------------- diam/server.go | 44 +++++++++++++++++++++----------------------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/diam/message.go b/diam/message.go index 96ae6fe..176db0f 100644 --- a/diam/message.go +++ b/diam/message.go @@ -64,19 +64,24 @@ func readerBufferSlice(buf *bytes.Buffer, l int) []byte { // ReadMessage reads a binary stream from the reader and uses the given // dictionary to parse it. -func ReadMessage(reader io.Reader, dictionary *dict.Parser, hook *HeaderReaderHook) (*Message, *string, error) { +func ReadMessage(reader io.Reader, dictionary *dict.Parser, hook func(*Message) error) (*Message, error) { buf := newReaderBuffer() defer putReaderBuffer(buf) m := &Message{dictionary: dictionary} - cmd, stream, endpointId, err := m.readHeader(reader, buf, hook) + + if err := hook(m); err != nil { + return nil, err + } + + cmd, stream, err := m.readHeader(reader, buf) if err != nil { - return nil, endpointId, err + return nil, err } m.stream = stream if err = m.readBody(reader, buf, cmd, stream); err != nil { - return nil, endpointId, err + return nil, err } - return m, endpointId, nil + return m, nil } // MessageStream returns the stream #, the message was received on (when applicable) @@ -84,14 +89,7 @@ func (m *Message) MessageStream() uint { return m.stream } -func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer, hook *HeaderReaderHook) (cmd *dict.Command, stream uint, endpointId *string, err error) { - if (hook != nil) { - endpointId, err = (*hook)(r) - if err != nil { - return nil, stream, nil, err - } - } - +func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer) (cmd *dict.Command, stream uint, err error) { b := buf.Bytes()[:HeaderLength] msr, isMulti := r.(MultistreamReader) if isMulti { @@ -103,21 +101,21 @@ func (m *Message) readHeader(r io.Reader, buf *bytes.Buffer, hook *HeaderReaderH _, err = io.ReadFull(r, b) } if err != nil { - return nil, stream, nil, err + return nil, stream, err } m.Header, err = DecodeHeader(b) if err != nil { - return nil, stream, endpointId, err + return nil, stream, err } cmd, err = m.Dictionary().FindCommand( m.Header.ApplicationID, m.Header.CommandCode, ) if err != nil { - return nil, stream, endpointId, err + return nil, stream, err } - return cmd, stream, endpointId, nil + return cmd, stream, nil } func (m *Message) readBody(r io.Reader, buf *bytes.Buffer, cmd *dict.Command, stream uint) error { diff --git a/diam/server.go b/diam/server.go index dd42689..4293260 100644 --- a/diam/server.go +++ b/diam/server.go @@ -43,7 +43,6 @@ type Conn interface { Context() context.Context // Returns the internal context SetContext(ctx context.Context) // Stores a new context Connection() net.Conn // Returns network connection - ProxyHeaderVpcEndpointId() *string // Returns the VPC Endpoint ID fetched by the header hook } // The CloseNotifier interface is implemented by Conns which @@ -88,10 +87,9 @@ type conn struct { tlsState *tls.ConnectionState // or nil when not using TLS writer *response // the diam.Conn exposed to handlers - mu sync.Mutex // guards the following - closeNotifyc chan struct{} - clientGone bool - vpcEndpointId *string // the AWS VPC endpoint ID extracted from the proxy header + mu sync.Mutex // guards the following + closeNotifyc chan struct{} + clientGone bool } func (c *conn) closeNotify() <-chan struct{} { @@ -164,17 +162,21 @@ func (c *conn) readMessage() (m *Message, err error) { if c.server.ReadTimeout > 0 { c.rwc.SetReadDeadline(time.Now().Add(c.server.ReadTimeout)) } + + wrappedMethod := func(m *Message) error { + return c.server.ReadMessageHook(c.writer, m) + } + if msc, isMulti := c.rwc.(MultistreamConn); isMulti { // If it's a multi-stream association - reset the stream to "undefined" prior to reading next message msc.ResetCurrentStream() - m, vpcEndpointId, err = ReadMessage(msc, c.dictionary(), c.server.HeaderHook) // MultistreamConn has it's own buffering + m, err = ReadMessage(msc, c.dictionary(), wrappedMethod) // MultistreamConn has it's own buffering } else { - m, vpcEndpointId, err = ReadMessage(c.buf.Reader, c.dictionary(), c.server.HeaderHook) + m, err = ReadMessage(c.buf.Reader, c.dictionary(), wrappedMethod) } if err != nil { return nil, err } - c.vpcEndpointId = vpcEndpointId return m, nil } @@ -341,10 +343,6 @@ func (w *response) Connection() net.Conn { return w.conn.rwc } -func (w *response) ProxyHeaderVpcEndpointId() *string { - return w.conn.vpcEndpointId -} - // The HandlerFunc type is an adapter to allow the use of // ordinary functions as diameter handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a @@ -564,20 +562,20 @@ func Serve(l net.Listener, handler Handler) error { return srv.Serve(l) } -type HeaderReaderHook = func(io.Reader) (*string, error) +type ReadMessageHook = func(Conn, *Message) error // A Server defines parameters for running a diameter server. type Server struct { - Network string // network of the address - empty string defaults to tcp - Addr string // address to listen on, ":3868" if empty - Handler Handler // handler to invoke, DefaultServeMux if nil - Dict *dict.Parser // diameter dictionaries for this server - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS - LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to - HeaderHook -}/ + Network string // network of the address - empty string defaults to tcp + Addr string // address to listen on, ":3868" if empty + Handler Handler // handler to invoke, DefaultServeMux if nil + Dict *dict.Parser // diameter dictionaries for this server + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to + ReadMessageHook ReadMessageHook +} // serverHandler delegates to either the server's Handler or DefaultServeMux. type serverHandler struct { From df101f19fbcd3cc73b1ea78f9e13d3cc90591977 Mon Sep 17 00:00:00 2001 From: Ahmed Anas Date: Wed, 29 Mar 2023 08:54:12 +0000 Subject: [PATCH 3/3] nil handling --- diam/server.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/diam/server.go b/diam/server.go index 4293260..0e44d9c 100644 --- a/diam/server.go +++ b/diam/server.go @@ -164,6 +164,10 @@ func (c *conn) readMessage() (m *Message, err error) { } wrappedMethod := func(m *Message) error { + if c.server.ReadMessageHook == nil { + return nil + } + return c.server.ReadMessageHook(c.writer, m) } @@ -566,15 +570,15 @@ type ReadMessageHook = func(Conn, *Message) error // A Server defines parameters for running a diameter server. type Server struct { - Network string // network of the address - empty string defaults to tcp - Addr string // address to listen on, ":3868" if empty - Handler Handler // handler to invoke, DefaultServeMux if nil - Dict *dict.Parser // diameter dictionaries for this server - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS - LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to - ReadMessageHook ReadMessageHook + Network string // network of the address - empty string defaults to tcp + Addr string // address to listen on, ":3868" if empty + Handler Handler // handler to invoke, DefaultServeMux if nil + Dict *dict.Parser // diameter dictionaries for this server + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to + ReadMessageHook ReadMessageHook // optional Called right before ReadMessage method. } // serverHandler delegates to either the server's Handler or DefaultServeMux.