diff --git a/nsqd/channel.go b/nsqd/channel.go index 3fc931c38..62175b622 100644 --- a/nsqd/channel.go +++ b/nsqd/channel.go @@ -47,9 +47,11 @@ type Channel struct { backend BackendQueue - memoryMsgChan chan *Message - exitFlag int32 - exitMutex sync.RWMutex + zoneLocalMsgChan chan *Message + regionLocalMsgChan chan *Message + memoryMsgChan chan *Message + exitFlag int32 + exitMutex sync.RWMutex // state tracking clients map[int64]Consumer @@ -82,6 +84,14 @@ func NewChannel(topicName string, channelName string, nsqd *NSQD, deleteCallback: deleteCallback, nsqd: nsqd, } + + if nsqd.getOpts().TopologyRegion != "" { + c.regionLocalMsgChan = make(chan *Message, 0) + } + if nsqd.getOpts().TopologyZone != "" { + c.zoneLocalMsgChan = make(chan *Message, 0) + } + // create mem-queue only if size > 0 (do not use unbuffered chan) if nsqd.getOpts().MemQueueSize > 0 { c.memoryMsgChan = make(chan *Message, nsqd.getOpts().MemQueueSize) @@ -302,16 +312,27 @@ func (c *Channel) PutMessage(m *Message) error { } func (c *Channel) put(m *Message) error { + select { + case c.zoneLocalMsgChan <- m: + return nil + default: + } + select { + case c.regionLocalMsgChan <- m: + return nil + default: + } select { case c.memoryMsgChan <- m: + return nil default: - err := writeMessageToBackend(m, c.backend) - c.nsqd.SetHealth(err) - if err != nil { - c.nsqd.logf(LOG_ERROR, "CHANNEL(%s): failed to write message to backend - %s", - c.name, err) - return err - } + } + err := writeMessageToBackend(m, c.backend) + c.nsqd.SetHealth(err) + if err != nil { + c.nsqd.logf(LOG_ERROR, "CHANNEL(%s): failed to write message to backend - %s", + c.name, err) + return err } return nil } diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index 64a48826b..2dc99b63e 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -203,7 +203,7 @@ func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) { func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { var err error - var memoryMsgChan chan *Message + var zoneMsgChan, regionMsgChan, memoryMsgChan chan *Message var backendMsgChan <-chan []byte var subChannel *Channel // NOTE: `flusherChan` is used to bound message latency for @@ -211,6 +211,7 @@ func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { // with >1 clients having >1 RDY counts var flusherChan <-chan time.Time var sampleRate int32 + var regionLocal, zoneLocal bool subEventChan := client.SubEventChan identifyEventChan := client.IdentifyEventChan @@ -232,9 +233,13 @@ func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { close(startedChan) for { + var b []byte + var msg *Message if subChannel == nil || !client.IsReadyForMessages() { // the client is not ready to receive messages... memoryMsgChan = nil + regionMsgChan = nil + zoneMsgChan = nil backendMsgChan = nil flusherChan = nil // force flush @@ -249,12 +254,24 @@ func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { // last iteration we flushed... // do not select on the flusher ticker channel memoryMsgChan = subChannel.memoryMsgChan + if zoneLocal { + zoneMsgChan = subChannel.zoneLocalMsgChan + } + if regionLocal { + regionMsgChan = subChannel.regionLocalMsgChan + } backendMsgChan = subChannel.backend.ReadChan() flusherChan = nil } else { // we're buffered (if there isn't any more data we should flush)... // select on the flusher ticker channel, too memoryMsgChan = subChannel.memoryMsgChan + if zoneLocal { + zoneMsgChan = subChannel.zoneLocalMsgChan + } + if regionLocal { + regionMsgChan = subChannel.regionLocalMsgChan + } backendMsgChan = subChannel.backend.ReadChan() flusherChan = outputBufferTicker.C } @@ -296,36 +313,37 @@ func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { } msgTimeout = identifyData.MsgTimeout + if identifyData.TopologyZone == p.nsqd.getOpts().TopologyZone { + zoneLocal = true + } + if identifyData.TopologyRegion == p.nsqd.getOpts().TopologyRegion { + regionLocal = true + } case <-heartbeatChan: err = p.Send(client, frameTypeResponse, heartbeatBytes) if err != nil { goto exit } - case b := <-backendMsgChan: - if sampleRate > 0 && rand.Int31n(100) > sampleRate { - continue - } - - msg, err := decodeMessage(b) + case b = <-backendMsgChan: + // decodeMessage then handle 'msg' + case msg = <-zoneMsgChan: + case msg = <-regionMsgChan: + case msg = <-memoryMsgChan: + case <-client.ExitChan: + goto exit + } + if len(b) != 0 { + msg, err = decodeMessage(b) if err != nil { p.nsqd.logf(LOG_ERROR, "failed to decode message - %s", err) continue } - msg.Attempts++ - - subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout) - client.SendingMessage() - err = p.SendMessage(client, msg) - if err != nil { - goto exit - } - flushed = false - case msg := <-memoryMsgChan: + } + if msg != nil { if sampleRate > 0 && rand.Int31n(100) > sampleRate { continue } msg.Attempts++ - subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout) client.SendingMessage() err = p.SendMessage(client, msg) @@ -333,9 +351,8 @@ func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) { goto exit } flushed = false - case <-client.ExitChan: - goto exit } + } exit: diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 5ff8ad9d5..390339e9e 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -210,6 +210,81 @@ func TestMultipleConsumerV2(t *testing.T) { test.Equal(t, uint16(1), msgOut.Attempts) } +// TestSameZoneConsumerV2 tests that a published message goes to same-zone consumer first +// if it's message pump is waiting +func TestSameZoneConsumerV2(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + opts.ClientTimeout = 60 * time.Second + opts.TopologyRegion = "region" + opts.TopologyZone = "zone" + tcpAddr, _, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + topicName := "test_zone_v2" + strconv.Itoa(int(time.Now().Unix())) + topic := nsqd.GetTopic(topicName) + msg := NewMessage(topic.GenerateID(), []byte("test body")) + topic.GetChannel("ch") + + var sameZone, diffZone int64 + var exiting int32 + done := make(chan bool, 21) + for _, zone := range []string{"zone", "zone", "zone2", "zone2"} { + zone := zone + conn, err := mustConnectNSQD(tcpAddr) + test.Nil(t, err) + defer conn.Close() + + identify(t, conn, map[string]interface{}{"topology_zone": zone}, frameTypeResponse) + sub(t, conn, topicName, "ch") + + _, err = nsq.Ready(10).WriteTo(conn) + test.Nil(t, err) + + go func(c net.Conn, zone string) { + for { + resp, err := nsq.ReadResponse(c) + if atomic.LoadInt32(&exiting) == 1 { + return + } + test.Nil(t, err) + _, data, err := nsq.UnpackResponse(resp) + test.Nil(t, err) + _, err = decodeMessage(data) + test.Nil(t, err) + if zone == "zone" { + atomic.AddInt64(&sameZone, 1) + } else { + atomic.AddInt64(&diffZone, 1) + } + done <- true + } + }(conn, zone) + } + + // first 20 messages go to same zone (each has RDY 10) + // next message goes to global memoryChan (All consumers) + for i := 0; i < 21; i++ { + topic.PutMessage(msg) + if i%2 == 0 { + // sleep long enough for messagePump to wait again + time.Sleep(time.Millisecond) + } + } + var doneCount int64 + for _ = range done { + doneCount += 1 + if doneCount == 21 { + break + } + } + t.Logf("got same zone %d diffZone %d", sameZone, diffZone) + atomic.StoreInt32(&exiting, 1) + test.Equal(t, int64(20), sameZone) + test.Equal(t, int64(1), diffZone) +} + func TestClientTimeout(t *testing.T) { topicName := "test_client_timeout_v2" + strconv.Itoa(int(time.Now().Unix()))