diff --git a/ipc/uds/highlevel/chunk/chunk.go b/ipc/uds/highlevel/chunk/chunk.go index 9206916..6d32cbf 100644 --- a/ipc/uds/highlevel/chunk/chunk.go +++ b/ipc/uds/highlevel/chunk/chunk.go @@ -89,9 +89,10 @@ type Client struct { writeVarInt []byte readVarInt []byte - maxSize int64 - biggestReadMsgSize int - biggestWriteMsgSize int + maxSize int64 + + closeOnce sync.Once + closed chan struct{} readMu, writeMu sync.Mutex } @@ -125,10 +126,12 @@ func New(rwc io.ReadWriteCloser, options ...Option) (*Client, error) { client = &Client{ rwc: rwc, writeVarInt: make([]byte, 8), + closed: make(chan struct{}), } default: return nil, fmt.Errorf("rwc was not a *uds.Client or *uds.Server, was %T", rwc) } + for _, o := range options { o(client) } @@ -148,8 +151,18 @@ func New(rwc io.ReadWriteCloser, options ...Option) (*Client, error) { return client, nil } +// ClosedSignal returns a channel that will be closed when this Client becomes closed. +func (c *Client) ClosedSignal() chan struct{} { + return c.closed +} + // Close closes the underlying io.ReadWriteCloser. func (c *Client) Close() error { + defer c.closeOnce.Do( + func() { + close(c.closed) + }, + ) return c.rwc.Close() } @@ -164,20 +177,20 @@ const ( oneMiB = 1000 * 1024 ) -// Read reads the next message from the socket. +// Read reads the next message from the socket. Any error closes the client. func (c *Client) Read() (*[]byte, error) { c.readMu.Lock() defer c.readMu.Unlock() size, err := binary.ReadVarint(c.rwc.(io.ByteReader)) if err != nil { - c.rwc.Close() + c.Close() return nil, err } if c.maxSize > 0 { if size > c.maxSize { - c.rwc.Close() + c.Close() return nil, fmt.Errorf("message is larger than maximum size allowed") } } @@ -193,7 +206,7 @@ func (c *Client) Read() (*[]byte, error) { return b, nil } -// Write writes b as a chunk into the socket. +// Write writes b as a chunk into the socket. Any error closes the client. func (c *Client) Write(b []byte) error { c.writeMu.Lock() defer c.writeMu.Unlock() @@ -211,7 +224,7 @@ func (c *Client) Write(b []byte) error { n := binary.PutVarint(c.writeVarInt, int64(len(b))) _, err := c.rwc.Write(c.writeVarInt[:n]) if err != nil { - c.rwc.Close() + c.Close() return err } n, _ = c.rwc.Write(b)