diff --git a/channelz/service/service_test.go b/channelz/service/service_test.go index 4e4154226bbb..8214f1223691 100644 --- a/channelz/service/service_test.go +++ b/channelz/service/service_test.go @@ -334,7 +334,7 @@ func (s) TestGetChannel(t *testing.T) { }, }) - subChan := channelz.RegisterSubChannel(cids[0].ID, refNames[2]) + subChan := channelz.RegisterSubChannel(cids[0], refNames[2]) channelz.AddTraceEvent(logger, subChan, 0, &channelz.TraceEvent{ Desc: "SubChannel Created", Severity: channelz.CtInfo, @@ -432,7 +432,7 @@ func (s) TestGetSubChannel(t *testing.T) { Desc: "Channel Created", Severity: channelz.CtInfo, }) - subChan := channelz.RegisterSubChannel(chann.ID, refNames[1]) + subChan := channelz.RegisterSubChannel(chann, refNames[1]) defer channelz.RemoveEntry(subChan.ID) channelz.AddTraceEvent(logger, subChan, 0, &channelz.TraceEvent{ Desc: subchanCreated, diff --git a/clientconn.go b/clientconn.go index d16d058b053a..c7f2607114a8 100644 --- a/clientconn.go +++ b/clientconn.go @@ -833,7 +833,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer. addrs: copyAddressesWithoutBalancerAttributes(addrs), scopts: opts, dopts: cc.dopts, - channelz: channelz.RegisterSubChannel(cc.channelz.ID, ""), + channelz: channelz.RegisterSubChannel(cc.channelz, ""), resetBackoff: make(chan struct{}), stateChan: make(chan struct{}), } diff --git a/internal/channelz/funcs.go b/internal/channelz/funcs.go index f461e9bc3baf..03e24e1507aa 100644 --- a/internal/channelz/funcs.go +++ b/internal/channelz/funcs.go @@ -143,20 +143,21 @@ func RegisterChannel(parent *Channel, target string) *Channel { // Returns a unique channelz identifier assigned to this subChannel. // // If channelz is not turned ON, the channelz database is not mutated. -func RegisterSubChannel(pid int64, ref string) *SubChannel { +func RegisterSubChannel(parent *Channel, ref string) *SubChannel { id := IDGen.genID() - if !IsOn() { - return &SubChannel{ID: id} - } - sc := &SubChannel{ - RefName: ref, ID: id, - sockets: make(map[int64]string), - parent: db.getChannel(pid), - trace: &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())}, + RefName: ref, + parent: parent, } - db.addSubChannel(id, sc, pid) + + if !IsOn() { + return sc + } + + sc.sockets = make(map[int64]string) + sc.trace = &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())} + db.addSubChannel(id, sc, parent.ID) return sc } diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index 3fafc38918dd..393a4540396f 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -249,6 +249,16 @@ func (s) TestKeepaliveServerWithResponsiveClient(t *testing.T) { } } +func channelzSubChannel(t *testing.T) *channelz.SubChannel { + ch := channelz.RegisterChannel(nil, "test chan") + sc := channelz.RegisterSubChannel(ch, "test subchan") + t.Cleanup(func() { + channelz.RemoveEntry(sc.ID) + channelz.RemoveEntry(ch.ID) + }) + return sc +} + // TestKeepaliveClientClosesUnresponsiveServer creates a server which does not // respond to keepalive pings, and makes sure that the client closes the // transport once the keepalive logic kicks in. Here, we set the @@ -257,14 +267,13 @@ func (s) TestKeepaliveServerWithResponsiveClient(t *testing.T) { func (s) TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { connCh := make(chan net.Conn, 1) copts := ConnectOptions{ - ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"), + ChannelzParent: channelzSubChannel(t), KeepaliveParams: keepalive.ClientParameters{ Time: 10 * time.Millisecond, Timeout: 10 * time.Millisecond, PermitWithoutStream: true, }, } - defer channelz.RemoveEntry(copts.ChannelzParent.ID) client, cancel := setUpWithNoPingServer(t, copts, connCh) defer cancel() defer client.Close(fmt.Errorf("closed manually by test")) @@ -288,13 +297,12 @@ func (s) TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) { func (s) TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { connCh := make(chan net.Conn, 1) copts := ConnectOptions{ - ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"), + ChannelzParent: channelzSubChannel(t), KeepaliveParams: keepalive.ClientParameters{ Time: 10 * time.Millisecond, Timeout: 10 * time.Millisecond, }, } - defer channelz.RemoveEntry(copts.ChannelzParent.ID) client, cancel := setUpWithNoPingServer(t, copts, connCh) defer cancel() defer client.Close(fmt.Errorf("closed manually by test")) @@ -319,13 +327,12 @@ func (s) TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) { func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { connCh := make(chan net.Conn, 1) copts := ConnectOptions{ - ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"), + ChannelzParent: channelzSubChannel(t), KeepaliveParams: keepalive.ClientParameters{ Time: 500 * time.Millisecond, Timeout: 500 * time.Millisecond, }, } - defer channelz.RemoveEntry(copts.ChannelzParent.ID) // TODO(i/6099): Setup a server which can ping and no-ping based on a flag to // reduce the flakiness in this test. client, cancel := setUpWithNoPingServer(t, copts, connCh) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 90ce78f42f7a..b0be89210564 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -434,8 +434,7 @@ func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) { func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { server := setUpServerOnly(t, port, sc, ht) addr := resolver.Address{Addr: "localhost:" + server.port} - copts.ChannelzParent = channelz.RegisterSubChannel(-1, "test channel") - t.Cleanup(func() { channelz.RemoveEntry(copts.ChannelzParent.ID) }) + copts.ChannelzParent = channelzSubChannel(t) connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}) @@ -1321,9 +1320,8 @@ func (s) TestClientHonorsConnectContext(t *testing.T) { connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) time.AfterFunc(100*time.Millisecond, cancel) - parent := channelz.RegisterSubChannel(-1, "test channel") + parent := channelzSubChannel(t) copts := ConnectOptions{ChannelzParent: parent} - defer channelz.RemoveEntry(parent.ID) _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err == nil { t.Fatalf("NewClientTransport() returned successfully; wanted error") @@ -1414,8 +1412,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) defer cancel() - parent := channelz.RegisterSubChannel(-1, "test channel") - defer channelz.RemoveEntry(parent.ID) + parent := channelzSubChannel(t) copts := ConnectOptions{ChannelzParent: parent} ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}) if err != nil { @@ -2425,9 +2422,8 @@ func (s) TestClientHandshakeInfo(t *testing.T) { copts := ConnectOptions{ TransportCredentials: creds, - ChannelzParent: channelz.RegisterSubChannel(-1, "test subchannel"), + ChannelzParent: channelzSubChannel(t), } - defer channelz.RemoveEntry(copts.ChannelzParent.ID) tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) @@ -2467,9 +2463,8 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) { copts := ConnectOptions{ Dialer: dialer, - ChannelzParent: channelz.RegisterSubChannel(-1, "test subchannel"), + ChannelzParent: channelzSubChannel(t), } - defer channelz.RemoveEntry(copts.ChannelzParent.ID) tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) diff --git a/test/channelz_test.go b/test/channelz_test.go index 1cc1e2e48de5..cd6b77af4469 100644 --- a/test/channelz_test.go +++ b/test/channelz_test.go @@ -554,8 +554,8 @@ func (s) TestCZRecusivelyDeletionOfEntry(t *testing.T) { // Socket1 Socket2 topChan := channelz.RegisterChannel(nil, "") - subChan1 := channelz.RegisterSubChannel(topChan.ID, "") - subChan2 := channelz.RegisterSubChannel(topChan.ID, "") + subChan1 := channelz.RegisterSubChannel(topChan, "") + subChan2 := channelz.RegisterSubChannel(topChan, "") skt1 := channelz.RegisterSocket(&channelz.Socket{SocketType: channelz.SocketTypeNormal, Parent: subChan1}) skt2 := channelz.RegisterSocket(&channelz.Socket{SocketType: channelz.SocketTypeNormal, Parent: subChan1})