diff --git a/nsqadmin/http_test.go b/nsqadmin/http_test.go index c0081bbbf..b415d808a 100644 --- a/nsqadmin/http_test.go +++ b/nsqadmin/http_test.go @@ -253,7 +253,8 @@ func TestHTTPChannelGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_channel_get" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) @@ -292,7 +293,8 @@ func TestHTTPNodesSingleGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_nodes_single_get" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) @@ -419,7 +421,8 @@ func TestHTTPDeleteChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_delete_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) @@ -474,7 +477,8 @@ func TestHTTPPauseChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_pause_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) @@ -509,7 +513,8 @@ func TestHTTPEmptyTopicPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_empty_topic_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.PutMessage(nsqd.NewMessage(nsqd.MessageID{}, []byte("1234"))) test.Equal(t, int64(1), topic.Depth()) time.Sleep(100 * time.Millisecond) @@ -537,7 +542,8 @@ func TestHTTPEmptyChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_empty_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetOrCreateTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) channel := topic.GetOrCreateChannel("ch") channel.PutMessage(nsqd.NewMessage(nsqd.MessageID{}, []byte("1234"))) diff --git a/nsqd/channel_test.go b/nsqd/channel_test.go index 561bad9c8..775519a85 100644 --- a/nsqd/channel_test.go +++ b/nsqd/channel_test.go @@ -21,8 +21,8 @@ func TestPutMessage(t *testing.T) { defer nsqd.Exit() topicName := "test_put_message" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel1 := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch") var id MessageID msg := NewMessage(id, []byte("test")) @@ -42,9 +42,9 @@ func TestPutMessage2Chan(t *testing.T) { defer nsqd.Exit() topicName := "test_put_message_2chan" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel1 := topic.GetOrCreateChannel("ch1") - channel2 := topic.GetOrCreateChannel("ch2") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch1") + channel2, _ := topic.GetOrCreateChannel("ch2") var id MessageID msg := NewMessage(id, []byte("test")) @@ -71,8 +71,8 @@ func TestInFlightWorker(t *testing.T) { defer nsqd.Exit() topicName := "test_in_flight_worker" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") for i := 0; i < count; i++ { msg := NewMessage(topic.GenerateID(), []byte("test")) @@ -112,8 +112,8 @@ func TestChannelEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") msgs := make([]*Message, 0, 25) for i := 0; i < 25; i++ { @@ -148,8 +148,8 @@ func TestChannelEmptyConsumer(t *testing.T) { defer conn.Close() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") client := newClientV2(0, conn, nsqd) client.SetReadyCount(25) err := channel.AddClient(client.ID, client) @@ -186,8 +186,8 @@ func TestMaxChannelConsumers(t *testing.T) { defer conn.Close() topicName := "test_max_channel_consumers" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") client1 := newClientV2(1, conn, nsqd) client1.SetReadyCount(25) @@ -209,9 +209,9 @@ func TestChannelHealth(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel := topic.GetOrCreateChannel("channel") + channel, _ := topic.GetOrCreateChannel("channel") channel.backend = &errorBackendQueue{} @@ -258,8 +258,8 @@ func TestChannelDraining(t *testing.T) { defer nsqd.Exit() topicName := "test_drain_channel" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel1 := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), []byte("test")) topic.PutMessage(msg) diff --git a/nsqd/http.go b/nsqd/http.go index 6c1b89c11..ae9572a34 100644 --- a/nsqd/http.go +++ b/nsqd/http.go @@ -201,8 +201,8 @@ func (s *httpServer) getTopicFromQuery(req *http.Request) (url.Values, *Topic, e if !protocol.IsValidTopicName(topicName) { return nil, nil, http_api.Err{400, "INVALID_TOPIC"} } - topic := s.nsqd.GetOrCreateTopic(topicName) - if topic == nil { + topic, err := s.nsqd.GetOrCreateTopic(topicName) + if err != nil { return nil, nil, http_api.Err{503, "EXITING"} } @@ -416,8 +416,8 @@ func (s *httpServer) doCreateChannel(w http.ResponseWriter, req *http.Request, p if err != nil { return nil, err } - ch := topic.GetOrCreateChannel(channelName) - if ch == nil { + _, err = topic.GetOrCreateChannel(channelName) + if err != nil { return nil, http_api.Err{503, "EXITING"} } return nil, nil diff --git a/nsqd/http_test.go b/nsqd/http_test.go index 5a77d3ce4..94ac5e580 100644 --- a/nsqd/http_test.go +++ b/nsqd/http_test.go @@ -46,7 +46,7 @@ func TestHTTPpub(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -69,7 +69,7 @@ func TestHTTPpubEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -93,7 +93,7 @@ func TestHTTPmpub(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := []byte("test message") msgs := make([][]byte, 4) @@ -122,7 +122,7 @@ func TestHTTPmpubEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := []byte("test message") msgs := make([][]byte, 4) @@ -153,7 +153,7 @@ func TestHTTPmpubBinary(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_bin" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) mpub := make([][]byte, 5) for i := range mpub { @@ -182,7 +182,7 @@ func TestHTTPmpubForNonNormalizedBinaryParam(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_bin" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) mpub := make([][]byte, 5) for i := range mpub { @@ -211,8 +211,8 @@ func TestHTTPpubDefer(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_defer" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - ch := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s&defer=%d", httpAddr, topicName, 1000) @@ -242,7 +242,7 @@ func TestHTTPSRequire(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_req" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -289,7 +289,7 @@ func TestHTTPSRequireVerify(t *testing.T) { httpsAddr := nsqd.httpsListener.Addr().(*net.TCPAddr) topicName := "test_http_pub_req_verf" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) // no cert buf := bytes.NewBuffer([]byte("test message")) @@ -353,7 +353,7 @@ func TestTLSRequireVerifyExceptHTTP(t *testing.T) { defer nsqd.Exit() topicName := "test_http_req_verf_except_http" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) // no cert buf := bytes.NewBuffer([]byte("test message")) @@ -761,7 +761,7 @@ func TestEmptyChannel(t *testing.T) { test.Equal(t, 404, resp.StatusCode) test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) url = fmt.Sprintf("http://%s/channel/empty?topic=%s&channel=%s", httpAddr, topicName, channelName) resp, err = http.Post(url, "application/json", nil) diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index a3540e526..b7c047aad 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -329,8 +329,8 @@ func (n *NSQD) LoadMetadata() error { n.logf(LOG_WARN, "skipping creation of invalid topic %s", t.Name) continue } - topic := n.GetOrCreateTopic(t.Name) - if topic == nil { + topic, err := n.GetOrCreateTopic(t.Name) + if err != nil { n.logf(LOG_WARN, "skipping creation of topic, nsqd draining %s", t.Name) continue } @@ -342,8 +342,8 @@ func (n *NSQD) LoadMetadata() error { n.logf(LOG_WARN, "skipping creation of invalid channel %s", c.Name) continue } - channel := topic.GetOrCreateChannel(c.Name) - if c.Paused && channel != nil { + channel, err := topic.GetOrCreateChannel(c.Name) + if c.Paused && err != nil { channel.Pause() } } @@ -460,14 +460,14 @@ func (n *NSQD) Exit() { // GetOrCreateTopic performs a thread safe operation to get an existing topic or create a new one // -// The creation might fail if nsqd is draining -func (n *NSQD) GetOrCreateTopic(topicName string) *Topic { +// An error will be returned if nsqd is draining +func (n *NSQD) GetOrCreateTopic(topicName string) (*Topic, error) { // most likely, we already have this topic, so try read lock first. n.RLock() t, ok := n.topicMap[topicName] n.RUnlock() if ok { - return t + return t, nil } n.Lock() @@ -475,11 +475,11 @@ func (n *NSQD) GetOrCreateTopic(topicName string) *Topic { t, ok = n.topicMap[topicName] if ok { n.Unlock() - return t + return t, nil } if atomic.LoadInt32(&n.isDraining) == 1 { // don't create new topics when nsqd is draining - return nil + return nil, errors.New("nsqd draining") } deleteCallback := func(t *Topic) { @@ -506,7 +506,7 @@ func (n *NSQD) GetOrCreateTopic(topicName string) *Topic { // if loading metadata at startup, no lookupd connections yet, topic started after load if atomic.LoadInt32(&n.isLoading) == 1 { - return t + return t, nil } // if using lookupd, make a blocking call to get the topics, and immediately create them. @@ -529,7 +529,7 @@ func (n *NSQD) GetOrCreateTopic(topicName string) *Topic { // now that all channels are added, start topic messagePump t.Start() - return t + return t, nil } // GetExistingTopic gets a topic only if it exists diff --git a/nsqd/nsqd_test.go b/nsqd/nsqd_test.go index 59b97471c..dc81c9ab8 100644 --- a/nsqd/nsqd_test.go +++ b/nsqd/nsqd_test.go @@ -73,14 +73,14 @@ func TestStartup(t *testing.T) { atomic.StoreInt32(&nsqd.isLoading, 0) body := make([]byte, 256) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < iterations; i++ { msg := NewMessage(topic.GenerateID(), body) topic.PutMessage(msg) } t.Logf("pulling from channel") - channel1 := topic.GetOrCreateChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") t.Logf("read %d msgs", iterations/2) for i := 0; i < iterations/2; i++ { @@ -124,12 +124,12 @@ func TestStartup(t *testing.T) { doneExitChan <- 1 }() - topic = nsqd.GetOrCreateTopic(topicName) + topic, _ = nsqd.GetOrCreateTopic(topicName) // should be empty; channel should have drained everything count := topic.Depth() test.Equal(t, int64(0), count) - channel1 = topic.GetOrCreateChannel("ch1") + channel1, _ = topic.GetOrCreateChannel("ch1") for { if channel1.Depth() == int64(iterations/2) { @@ -176,8 +176,8 @@ func TestEphemeralTopicsAndChannels(t *testing.T) { }() body := []byte("an_ephemeral_message") - topic := nsqd.GetOrCreateTopic(topicName) - ephemeralChannel := topic.GetOrCreateChannel("ch1#ephemeral") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ephemeralChannel, _ := topic.GetOrCreateChannel("ch1#ephemeral") client := newClientV2(0, nil, nsqd) err := ephemeralChannel.AddClient(client.ID, client) test.Equal(t, err, nil) @@ -215,8 +215,8 @@ func TestPauseMetadata(t *testing.T) { // avoid concurrency issue of async PersistMetadata() calls atomic.StoreInt32(&nsqd.isLoading, 1) topicName := "pause_metadata" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") atomic.StoreInt32(&nsqd.isLoading, 0) nsqd.PersistMetadata() diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index 41482fd70..1c1d673e3 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -616,15 +616,15 @@ func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) { // Avoid adding a client to an ephemeral channel / topic which has started exiting. var channel *Channel for { - topic := p.nsqd.GetOrCreateTopic(topicName) - if topic == nil { + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { // topic creation might be blocked because of draining return nil, protocol.NewFatalClientErr(nil, "E_NSQD_DRAINING", fmt.Sprintf("SUB create channel %s:%s failed. nsqd is draining", topicName, channelName)) } - channel = topic.GetOrCreateChannel(channelName) - if channel == nil { + channel, _ = topic.GetOrCreateChannel(channelName) + if err != nil { // channel creation might be blocked because of draining return nil, protocol.NewFatalClientErr(nil, "E_TOPIC_DRAINING", fmt.Sprintf("SUB create channel %s:%s failed. Topic is draining with no messages left", @@ -814,8 +814,8 @@ func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetOrCreateTopic(topicName) - if topic == nil { + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed. nsqd draining") } msg := NewMessage(topic.GenerateID(), messageBody) @@ -846,8 +846,8 @@ func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetOrCreateTopic(topicName) - if topic == nil { + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed. nsqd draining") } @@ -936,8 +936,8 @@ func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetOrCreateTopic(topicName) - if topic == nil { + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed. nsqd draining") } msg := NewMessage(topic.GenerateID(), messageBody) diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 980076e0b..68d29537b 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -137,7 +137,7 @@ func TestBasicV2(t *testing.T) { defer nsqd.Exit() topicName := "test_v2" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -172,7 +172,7 @@ func TestMultipleConsumerV2(t *testing.T) { defer nsqd.Exit() topicName := "test_multiple_v2" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.GetOrCreateChannel("ch1") topic.GetOrCreateChannel("ch2") @@ -382,9 +382,9 @@ func TestPausing(t *testing.T) { _, err = nsq.Ready(1).WriteTo(conn) test.Nil(t, err) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) - channel := topic.GetOrCreateChannel("ch") + channel, _ := topic.GetOrCreateChannel("ch") topic.PutMessage(msg) // receive the first message via the client, finish it, and send new RDY @@ -590,7 +590,8 @@ func TestDPUB(t *testing.T) { time.Sleep(25 * time.Millisecond) - ch := nsqd.GetOrCreateTopic(topicName).GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") ch.deferredMutex.Lock() numDef := len(ch.deferredMessages) ch.deferredMutex.Unlock() @@ -624,8 +625,8 @@ func TestTouch(t *testing.T) { identify(t, conn, nil, frameTypeResponse) sub(t, conn, topicName, "ch") - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -667,7 +668,7 @@ func TestMaxRdyCount(t *testing.T) { test.Nil(t, err) defer conn.Close() - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -743,7 +744,7 @@ func TestOutputBuffering(t *testing.T) { outputBufferSize := 256 * 1024 outputBufferTimeout := 500 - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), make([]byte, outputBufferSize-1024)) topic.PutMessage(msg) @@ -1139,7 +1140,7 @@ func TestSnappy(t *testing.T) { _, err = nsq.Ready(1).WriteTo(rw) test.Nil(t, err) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), msgBody) topic.PutMessage(msg) @@ -1232,12 +1233,12 @@ func TestSampling(t *testing.T) { test.Equal(t, int32(sampleRate), r.SampleRate) topicName := "test_sampling" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < num; i++ { msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) } - channel := topic.GetOrCreateChannel("ch") + channel, _ := topic.GetOrCreateChannel("ch") // let the topic drain into the channel time.Sleep(50 * time.Millisecond) @@ -1336,8 +1337,8 @@ func TestClientMsgTimeout(t *testing.T) { defer nsqd.Exit() topicName := "test_cmsg_timeout" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - ch := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), make([]byte, 100)) topic.PutMessage(msg) @@ -1429,8 +1430,8 @@ func TestReqTimeoutRange(t *testing.T) { identify(t, conn, nil, frameTypeResponse) sub(t, conn, topicName, "ch") - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -1796,7 +1797,7 @@ func benchmarkProtocolV2Sub(b *testing.B, size int) { defer os.RemoveAll(opts.DataPath) msg := make([]byte, size) topicName := "bench_v2_sub" + strconv.Itoa(b.N) + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < b.N; i++ { msg := NewMessage(topic.GenerateID(), msg) topic.PutMessage(msg) @@ -1896,7 +1897,7 @@ func benchmarkProtocolV2MultiSub(b *testing.B, num int) { workers := runtime.GOMAXPROCS(0) for i := 0; i < num; i++ { topicName := "bench_v2" + strconv.Itoa(b.N) + "_" + strconv.Itoa(i) + "_" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < b.N; i++ { msg := NewMessage(topic.GenerateID(), msg) topic.PutMessage(msg) diff --git a/nsqd/stats_test.go b/nsqd/stats_test.go index 2775d5aae..a1663ad5e 100644 --- a/nsqd/stats_test.go +++ b/nsqd/stats_test.go @@ -22,12 +22,12 @@ func TestStats(t *testing.T) { defer nsqd.Exit() topicName := "test_stats" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) accompanyTopicName := "accompany_test_stats" + strconv.Itoa(int(time.Now().Unix())) - accompanyTopic := nsqd.GetOrCreateTopic(accompanyTopicName) + accompanyTopic, _ := nsqd.GetOrCreateTopic(accompanyTopicName) msg = NewMessage(accompanyTopic.GenerateID(), []byte("accompany test body")) accompanyTopic.PutMessage(msg) @@ -126,8 +126,8 @@ func TestStatsChannelLocking(t *testing.T) { defer nsqd.Exit() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) - channel := topic.GetOrCreateChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") var wg sync.WaitGroup diff --git a/nsqd/topic.go b/nsqd/topic.go index f8be81087..76e3dcdb3 100644 --- a/nsqd/topic.go +++ b/nsqd/topic.go @@ -134,12 +134,12 @@ func (t *Topic) Exiting() bool { // to return a Channel object (potentially new) // // The creation might fail if the topic is draining and no messages are outstanding -func (t *Topic) GetOrCreateChannel(channelName string) *Channel { +func (t *Topic) GetOrCreateChannel(channelName string) (*Channel, error) { t.Lock() - channel, isNew := t.getOrCreateChannel(channelName) + channel, isNew, err := t.getOrCreateChannel(channelName) t.Unlock() - if isNew { + if isNew && err != nil { // update messagePump state select { case t.channelUpdateChan <- 1: @@ -147,17 +147,17 @@ func (t *Topic) GetOrCreateChannel(channelName string) *Channel { } } - return channel + return channel, err } // getOrCreateChannel expects the caller to handle locking -func (t *Topic) getOrCreateChannel(channelName string) (*Channel, bool) { +func (t *Topic) getOrCreateChannel(channelName string) (c *Channel, isNew bool, err error) { channel, ok := t.channelMap[channelName] if !ok { if atomic.LoadInt32(&t.isDraining) == 1 { // if this topic is draining, and there are no messages on the topic don't create a new channel if t.Depth() == 0 { - return nil, false + return nil, false, errors.New("topic draining") } } @@ -178,9 +178,9 @@ func (t *Topic) getOrCreateChannel(channelName string) (*Channel, bool) { channel = NewChannel(t.name, channelName, t.nsqd, deleteCallback) t.channelMap[channelName] = channel t.nsqd.logf(LOG_INFO, "TOPIC(%s): new channel(%s)", t.name, channel.name) - return channel, true + return channel, true, nil } - return channel, false + return channel, false, nil } func (t *Topic) GetExistingChannel(channelName string) (*Channel, error) { diff --git a/nsqd/topic_test.go b/nsqd/topic_test.go index 173844e5c..42634f918 100644 --- a/nsqd/topic_test.go +++ b/nsqd/topic_test.go @@ -23,14 +23,14 @@ func TestGetTopic(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic1 := nsqd.GetOrCreateTopic("test") + topic1, _ := nsqd.GetOrCreateTopic("test") test.NotNil(t, topic1) test.Equal(t, "test", topic1.name) - topic2 := nsqd.GetOrCreateTopic("test") + topic2, _ := nsqd.GetOrCreateTopic("test") test.Equal(t, topic1, topic2) - topic3 := nsqd.GetOrCreateTopic("test2") + topic3, _ := nsqd.GetOrCreateTopic("test2") test.Equal(t, "test2", topic3.name) test.NotEqual(t, topic2, topic3) } @@ -42,13 +42,13 @@ func TestGetChannel(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetOrCreateChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) test.Equal(t, "ch1", channel1.name) - channel2 := topic.GetOrCreateChannel("ch2") + channel2, _ := topic.GetOrCreateChannel("ch2") test.Equal(t, channel1, topic.channelMap["ch1"]) test.Equal(t, channel2, topic.channelMap["ch2"]) @@ -75,7 +75,7 @@ func TestHealth(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") topic.backend = &errorBackendQueue{} msg := NewMessage(topic.GenerateID(), make([]byte, 100)) @@ -123,16 +123,16 @@ func TestDeletes(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetOrCreateChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) err := topic.DeleteExistingChannel("ch1") test.Nil(t, err) test.Equal(t, 0, len(topic.channelMap)) - channel2 := topic.GetOrCreateChannel("ch2") + channel2, _ := topic.GetOrCreateChannel("ch2") test.NotNil(t, channel2) err = nsqd.DeleteExistingTopic("test") @@ -148,9 +148,9 @@ func TestDeleteLast(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetOrCreateChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) err := topic.DeleteExistingChannel("ch1") @@ -172,11 +172,11 @@ func TestPause(t *testing.T) { defer nsqd.Exit() topicName := "test_topic_pause" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) err := topic.Pause() test.Nil(t, err) - channel := topic.GetOrCreateChannel("ch1") + channel, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel) msg := NewMessage(topic.GenerateID(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) @@ -204,11 +204,11 @@ func TestDrainEmpty(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic("drain_topic_empty") + topic, _ := nsqd.GetOrCreateTopic("drain_topic_empty") - channel1 := topic.GetOrCreateChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) - channel2 := topic.GetOrCreateChannel("ch2") + channel2, _ := topic.GetOrCreateChannel("ch2") test.NotNil(t, channel2) test.Equal(t, 2, len(topic.channelMap)) @@ -243,7 +243,7 @@ func BenchmarkTopicPut(b *testing.B) { b.StartTimer() for i := 0; i <= b.N; i++ { - topic := nsqd.GetOrCreateTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(guid(gf.NextMessageID()).Hex(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) topic.PutMessage(msg) } @@ -260,12 +260,12 @@ func BenchmarkTopicToChannelPut(b *testing.B) { _, _, nsqd := mustStartNSQD(opts) defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - channel := nsqd.GetOrCreateTopic(topicName).GetOrCreateChannel(channelName) + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel(channelName) gf := &test.GUIDFactory{} b.StartTimer() for i := 0; i <= b.N; i++ { - topic := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(guid(gf.NextMessageID()).Hex(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) topic.PutMessage(msg) } @@ -289,8 +289,8 @@ func BenchmarkTopicMessagePump(b *testing.B) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetOrCreateTopic(topicName) - ch := topic.GetOrCreateChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") ctx, cancel := context.WithCancel(context.Background()) gf := &test.GUIDFactory{}