Skip to content
This repository has been archived by the owner on Mar 11, 2020. It is now read-only.

pub/sub didn't work because of close clientChan #24

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ func hashValueReply(v HashValue) (*MultiBulkReply, error) {
return MultiBulkFromMap(m), nil
}

func indexValueReply(v map[int][]byte) (*MultiBulkReply, error) {
fmt.Println(v)
fmt.Println(len(v))
m := make([]interface{}, len(v)*2)
i := 0
for k, v := range v {
m[i] = v
m[i+1] = k
i += 2
}
return &MultiBulkReply{values: m}, nil
}

func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error) {
Debugf("CREATE REPLY: %T", val)
switch v := val.(type) {
Expand All @@ -139,6 +152,8 @@ func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error)
return hashValueReply(v)
case map[string][]byte:
return hashValueReply(v)
case map[int][]byte:
return indexValueReply(v)
case map[string]interface{}:
return MultiBulkFromMap(v), nil
case int:
Expand All @@ -156,7 +171,7 @@ func (srv *Server) createReply(r *Request, val interface{}) (ReplyWriter, error)
case *MultiChannelWriter:
println("New client")
for _, mcw := range v.Chans {
mcw.clientChan = r.ClientChan
mcw.ClientChan = r.ClientChan
}
return v, nil
default:
Expand Down
21 changes: 10 additions & 11 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bufio"
"fmt"
"io"
"io/ioutil"
"strconv"
"strings"
)

Expand All @@ -13,6 +13,7 @@ func parseRequest(conn io.ReadCloser) (*Request, error) {
// first line of redis request should be:
// *<number of arguments>CRLF
line, err := r.ReadString('\n')
// fmt.Println(line)
if err != nil {
return nil, err
}
Expand All @@ -21,7 +22,8 @@ func parseRequest(conn io.ReadCloser) (*Request, error) {

// Multiline request:
if line[0] == '*' {
if _, err := fmt.Sscanf(line, "*%d\r", &argsCount); err != nil {
argsCount, err = strconv.Atoi(strings.Trim(line, "* \r\n"))
if err != nil {
return nil, malformed("*<numberOfArguments>", line)
}
// All next lines are pairs of:
Expand All @@ -32,14 +34,12 @@ func parseRequest(conn io.ReadCloser) (*Request, error) {
if err != nil {
return nil, err
}

args := make([][]byte, argsCount-1)
for i := 0; i < argsCount-1; i += 1 {
if args[i], err = readArgument(r); err != nil {
return nil, err
}
}

return &Request{
Name: strings.ToLower(string(firstArg)),
Args: args,
Expand All @@ -56,6 +56,7 @@ func parseRequest(conn io.ReadCloser) (*Request, error) {
args = append(args, []byte(arg))
}
}
fmt.Println(strings.ToLower(string(fields[0])))
return &Request{
Name: strings.ToLower(string(fields[0])),
Args: args,
Expand All @@ -71,19 +72,17 @@ func readArgument(r *bufio.Reader) ([]byte, error) {
return nil, malformed("$<argumentLength>", line)
}
var argSize int
if _, err := fmt.Sscanf(line, "$%d\r", &argSize); err != nil {
argSize, err = strconv.Atoi(strings.Trim(line, "$ \r\n"))
if err != nil {
return nil, malformed("$<argumentSize>", line)
}

// I think int is safe here as the max length of request
// should be less then max int value?
data, err := ioutil.ReadAll(io.LimitReader(r, int64(argSize)))
data := make([]byte, argSize)
n, err := io.ReadFull(r, data)
if err != nil {
return nil, err
}

if len(data) != argSize {
return nil, malformedLength(argSize, len(data))
return nil, malformedLength(argSize, n)
}

// Now check for trailing CR
Expand Down
4 changes: 2 additions & 2 deletions reply.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (c *MultiChannelWriter) WriteTo(w io.Writer) (n int64, err error) {
type ChannelWriter struct {
FirstReply []interface{}
Channel chan []interface{}
clientChan chan struct{}
ClientChan chan struct{}
}

func (c *ChannelWriter) WriteTo(w io.Writer) (int64, error) {
Expand All @@ -191,7 +191,7 @@ func (c *ChannelWriter) WriteTo(w io.Writer) (int64, error) {

for {
select {
case <-c.clientChan:
case <-c.ClientChan:
return totalBytes, err
case reply := <-c.Channel:
if reply == nil {
Expand Down
51 changes: 50 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (srv *Server) ServeClient(conn net.Conn) (err error) {
clientChan := make(chan struct{})

// Read on `conn` in order to detect client disconnect
go func() {
defer func() {
// Close chan in order to trigger eventual selects
defer close(clientChan)
defer Debugf("Client disconnected")
Expand Down Expand Up @@ -106,6 +106,55 @@ func (srv *Server) ServeClient(conn net.Conn) (err error) {
return nil
}

func (srv *Server) ServeReplClient(conn net.Conn) (err error) {
defer func() {
if err != nil {
fmt.Fprintf(conn, "-%s\n", err)
}
conn.Close()
}()

clientChan := make(chan struct{})

// Read on `conn` in order to detect client disconnect
defer func() {
// Close chan in order to trigger eventual selects
defer close(clientChan)
defer Debugf("Client disconnected")
// FIXME: move conn within the request.
if false {
io.Copy(ioutil.Discard, conn)
}
}()

var clientAddr string

switch co := conn.(type) {
case *net.UnixConn:
f, err := conn.(*net.UnixConn).File()
if err != nil {
return err
}
clientAddr = f.Name()
default:
clientAddr = co.RemoteAddr().String()
}

for {
request, err := parseRequest(conn)
if err != nil {
return err
}
request.Host = clientAddr
request.ClientChan = clientChan
_, err = srv.Apply(request)
if err != nil {
return err
}
}
return nil
}

func NewServer(c *Config) (*Server, error) {
srv := &Server{
Proto: c.proto,
Expand Down