From 317271b232677b7869576a49855b01b9f4775d67 Mon Sep 17 00:00:00 2001 From: Arjan Singh Bal <46515553+arjan-bal@users.noreply.github.com> Date: Thu, 5 Dec 2024 10:50:07 +0530 Subject: [PATCH] pickfirst: Register a health listener when used as a leaf policy (#7832) --- .../pickfirst/pickfirstleaf/pickfirstleaf.go | 173 +++++++++--- .../pickfirstleaf/pickfirstleaf_ext_test.go | 248 ++++++++++++++++++ 2 files changed, 384 insertions(+), 37 deletions(-) diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go index 7a250e6e3217..2fc0a71f9441 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go @@ -54,6 +54,10 @@ func init() { balancer.Register(pickfirstBuilder{}) } +// enableHealthListenerKeyType is a unique key type used in resolver attributes +// to indicate whether the health listener usage is enabled. +type enableHealthListenerKeyType struct{} + var ( logger = grpclog.Component("pick-first-leaf-lb") // Name is the name of the pick_first_leaf balancer. @@ -109,10 +113,8 @@ func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) target: bo.Target.String(), metricsRecorder: bo.MetricsRecorder, // ClientConn will always create a Metrics Recorder. - addressList: addressList{}, subConns: resolver.NewAddressMap(), state: connectivity.Connecting, - mu: sync.Mutex{}, cancelConnectionTimer: func() {}, } b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) @@ -131,6 +133,13 @@ func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalan return cfg, nil } +// EnableHealthListener updates the state to configure pickfirst for using a +// generic health listener. +func EnableHealthListener(state resolver.State) resolver.State { + state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true) + return state +} + type pfConfig struct { serviceconfig.LoadBalancingConfig `json:"-"` @@ -148,15 +157,19 @@ type scData struct { subConn balancer.SubConn addr resolver.Address - state connectivity.State + rawConnectivityState connectivity.State + // The effective connectivity state based on raw connectivity, health state + // and after following sticky TransientFailure behaviour defined in A62. + effectiveState connectivity.State lastErr error connectionFailedInFirstPass bool } func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { sd := &scData{ - state: connectivity.Idle, - addr: addr, + rawConnectivityState: connectivity.Idle, + effectiveState: connectivity.Idle, + addr: addr, } sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ StateListener: func(state balancer.SubConnState) { @@ -181,7 +194,9 @@ type pickfirstBalancer struct { // The mutex is used to ensure synchronization of updates triggered // from the idle picker and the already serialized resolver, // SubConn state updates. - mu sync.Mutex + mu sync.Mutex + // State reported to the channel based on SubConn states and resolver + // updates. state connectivity.State // scData for active subonns mapped by address. subConns *resolver.AddressMap @@ -189,6 +204,7 @@ type pickfirstBalancer struct { firstPass bool numTF int cancelConnectionTimer func() + healthCheckingEnabled bool } // ResolverError is called by the ClientConn when the name resolver produces @@ -214,7 +230,7 @@ func (b *pickfirstBalancer) resolverErrorLocked(err error) { return } - b.cc.UpdateState(balancer.State{ + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, }) @@ -227,12 +243,12 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { // Cleanup state pertaining to the previous resolver state. // Treat an empty address list like an error by calling b.ResolverError. - b.state = connectivity.TransientFailure b.closeSubConnsLocked() b.addressList.updateAddrs(nil) b.resolverErrorLocked(errors.New("produced zero addresses")) return balancer.ErrBadResolverState } + b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil cfg, ok := state.BalancerConfig.(pfConfig) if state.BalancerConfig != nil && !ok { return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState) @@ -279,12 +295,15 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState newAddrs = deDupAddresses(newAddrs) newAddrs = interleaveAddresses(newAddrs) - // If the previous ready SubConn exists in new address list, - // keep this connection and don't create new SubConns. prevAddr := b.addressList.currentAddress() + prevSCData, found := b.subConns.Get(prevAddr) prevAddrsCount := b.addressList.size() + isPrevRawConnectivityStateReady := found && prevSCData.(*scData).rawConnectivityState == connectivity.Ready b.addressList.updateAddrs(newAddrs) - if b.state == connectivity.Ready && b.addressList.seekTo(prevAddr) { + + // If the previous ready SubConn exists in new address list, + // keep this connection and don't create new SubConns. + if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) { return nil } @@ -296,10 +315,9 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState // we should still enter CONNECTING because the sticky TF behaviour // mentioned in A62 applies only when the TRANSIENT_FAILURE is reported // due to connectivity failures. - if b.state == connectivity.Ready || b.state == connectivity.Connecting || prevAddrsCount == 0 { + if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 { // Start connection attempt at first address. - b.state = connectivity.Connecting - b.cc.UpdateState(balancer.State{ + b.forceUpdateConcludedStateLocked(balancer.State{ ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) @@ -501,7 +519,7 @@ func (b *pickfirstBalancer) requestConnectionLocked() { } scd := sd.(*scData) - switch scd.state { + switch scd.rawConnectivityState { case connectivity.Idle: scd.subConn.Connect() b.scheduleNextConnectionLocked() @@ -519,7 +537,7 @@ func (b *pickfirstBalancer) requestConnectionLocked() { b.scheduleNextConnectionLocked() return default: - b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", scd.state) + b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", scd.rawConnectivityState) return } @@ -562,16 +580,17 @@ func (b *pickfirstBalancer) scheduleNextConnectionLocked() { func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) { b.mu.Lock() defer b.mu.Unlock() - oldState := sd.state - sd.state = newState.ConnectivityState + oldState := sd.rawConnectivityState + sd.rawConnectivityState = newState.ConnectivityState // Previously relevant SubConns can still callback with state updates. // To prevent pickers from returning these obsolete SubConns, this logic // is included to check if the current list of active SubConns includes this // SubConn. - if activeSD, found := b.subConns.Get(sd.addr); !found || activeSD != sd { + if !b.isActiveSCData(sd) { return } if newState.ConnectivityState == connectivity.Shutdown { + sd.effectiveState = connectivity.Shutdown return } @@ -590,10 +609,30 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses) return } - b.state = connectivity.Ready - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + if !b.healthCheckingEnabled { + if b.logger.V(2) { + b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn) + } + + sd.effectiveState = connectivity.Ready + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + }) + return + } + if b.logger.V(2) { + b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn) + } + // Send a CONNECTING update to take the SubConn out of sticky-TF if + // required. + sd.effectiveState = connectivity.Connecting + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) { + b.updateSubConnHealthState(sd, scs) }) return } @@ -604,11 +643,13 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub // a transport is successfully created, but the connection fails // before the SubConn can send the notification for READY. We treat // this as a successful connection and transition to IDLE. - if (b.state == connectivity.Ready && newState.ConnectivityState != connectivity.Ready) || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) { + // TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second + // part of the if condition below once the issue is fixed. + if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) { // Once a transport fails, the balancer enters IDLE and starts from // the first address when the picker is used. b.shutdownRemainingLocked(sd) - b.state = connectivity.Idle + sd.effectiveState = newState.ConnectivityState // READY SubConn interspliced in between CONNECTING and IDLE, need to // account for that. if oldState == connectivity.Connecting { @@ -619,7 +660,7 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub } disconnectionsMetric.Record(b.metricsRecorder, 1, b.target) b.addressList.reset() - b.cc.UpdateState(balancer.State{ + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Idle, Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)}, }) @@ -629,19 +670,19 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub if b.firstPass { switch newState.ConnectivityState { case connectivity.Connecting: - // The balancer can be in either IDLE, CONNECTING or - // TRANSIENT_FAILURE. If it's in TRANSIENT_FAILURE, stay in + // The effective state can be in either IDLE, CONNECTING or + // TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in // TRANSIENT_FAILURE until it's READY. See A62. - // If the balancer is already in CONNECTING, no update is needed. - if b.state == connectivity.Idle { - b.state = connectivity.Connecting - b.cc.UpdateState(balancer.State{ + if sd.effectiveState != connectivity.TransientFailure { + sd.effectiveState = connectivity.Connecting + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) } case connectivity.TransientFailure: sd.lastErr = newState.ConnectionError + sd.effectiveState = connectivity.TransientFailure // Since we're re-using common SubConns while handling resolver // updates, we could receive an out of turn TRANSIENT_FAILURE from // a pass over the previous address list. Happy Eyeballs will also @@ -668,7 +709,7 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub b.numTF = (b.numTF + 1) % b.subConns.Len() sd.lastErr = newState.ConnectionError if b.numTF%b.subConns.Len() == 0 { - b.cc.UpdateState(balancer.State{ + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: newState.ConnectionError}, }) @@ -698,21 +739,79 @@ func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) { } } b.firstPass = false - b.state = connectivity.TransientFailure - - b.cc.UpdateState(balancer.State{ + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: lastErr}, }) // Start re-connecting all the SubConns that are already in IDLE. for _, v := range b.subConns.Values() { sd := v.(*scData) - if sd.state == connectivity.Idle { + if sd.rawConnectivityState == connectivity.Idle { sd.subConn.Connect() } } } +func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool { + activeSD, found := b.subConns.Get(sd.addr) + return found && activeSD == sd +} + +func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) { + b.mu.Lock() + defer b.mu.Unlock() + // Previously relevant SubConns can still callback with state updates. + // To prevent pickers from returning these obsolete SubConns, this logic + // is included to check if the current list of active SubConns includes + // this SubConn. + if !b.isActiveSCData(sd) { + return + } + sd.effectiveState = state.ConnectivityState + switch state.ConnectivityState { + case connectivity.Ready: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + }) + case connectivity.TransientFailure: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)}, + }) + case connectivity.Connecting: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + default: + b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state) + } +} + +// updateBalancerState stores the state reported to the channel and calls +// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate +// updates to the channel. +func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) { + // In case of TransientFailures allow the picker to be updated to update + // the connectivity error, in all other cases don't send duplicate state + // updates. + if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure { + return + } + b.forceUpdateConcludedStateLocked(newState) +} + +// forceUpdateConcludedStateLocked stores the state reported to the channel and +// calls ClientConn.UpdateState(). +// A separate function is defined to force update the ClientConn state since the +// channel doesn't correctly assume that LB policies start in CONNECTING and +// relies on LB policy to send an initial CONNECTING update. +func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) { + b.state = newState.ConnectivityState + b.cc.UpdateState(newState) +} + type picker struct { result balancer.PickResult err error diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go index 9e835a6731b8..9667c2b3db6b 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" @@ -1224,6 +1225,253 @@ func (s) TestPickFirstLeaf_InterleavingUnknownPreffered(t *testing.T) { } } +// Test verifies that pickfirst balancer transitions to READY when the health +// listener is enabled. Since client side health checking is not enabled in +// the service config, the health listener will send a health update for READY +// after registering the listener. +func (s) TestPickFirstLeaf_HealthListenerEnabled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + bf := stub.BalancerFuncs{ + Init: func(bd *stub.BalancerData) { + bd.Data = balancer.Get(pickfirstleaf.Name).Build(bd.ClientConn, bd.BuildOptions) + }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState) + return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) + }, + } + + stub.Register(t.Name(), bf) + svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) + backend := stubserver.StartTestService(t, nil) + defer backend.Stop() + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(svcCfg), + } + cc, err := grpc.NewClient(backend.Address, opts...) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) + + } + defer cc.Close() + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, resolver.Address{Addr: backend.Address}); err != nil { + t.Fatal(err) + } +} + +// Test verifies that a health listener is not registered when pickfirst is not +// under a petiole policy. +func (s) TestPickFirstLeaf_HealthListenerNotEnabled(t *testing.T) { + // Wrap the clientconn to intercept NewSubConn. + // Capture the health list by wrapping the SC. + // Wrap the picker to unwrap the SC. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + healthListenerCh := make(chan func(balancer.SubConnState)) + + bf := stub.BalancerFuncs{ + Init: func(bd *stub.BalancerData) { + ccw := &healthListenerCapturingCCWrapper{ + ClientConn: bd.ClientConn, + healthListenerCh: healthListenerCh, + subConnStateCh: make(chan balancer.SubConnState, 5), + } + bd.Data = balancer.Get(pickfirstleaf.Name).Build(ccw, bd.BuildOptions) + }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + // Functions like a non-petiole policy by not configuring the use + // of health listeners. + return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) + }, + } + + stub.Register(t.Name(), bf) + svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) + backend := stubserver.StartTestService(t, nil) + defer backend.Stop() + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(svcCfg), + } + cc, err := grpc.NewClient(backend.Address, opts...) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) + + } + defer cc.Close() + cc.Connect() + + select { + case <-healthListenerCh: + t.Fatal("Health listener registered when not enabled.") + case <-time.After(defaultTestShortTimeout): + } + + testutils.AwaitState(ctx, t, cc, connectivity.Ready) +} + +// Test mocks the updates sent to the health listener and verifies that the +// balancer correctly reports the health state once the SubConn's connectivity +// state becomes READY. +func (s) TestPickFirstLeaf_HealthUpdates(t *testing.T) { + // Wrap the clientconn to intercept NewSubConn. + // Capture the health list by wrapping the SC. + // Wrap the picker to unwrap the SC. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + healthListenerCh := make(chan func(balancer.SubConnState)) + scConnectivityStateCh := make(chan balancer.SubConnState, 5) + + bf := stub.BalancerFuncs{ + Init: func(bd *stub.BalancerData) { + ccw := &healthListenerCapturingCCWrapper{ + ClientConn: bd.ClientConn, + healthListenerCh: healthListenerCh, + subConnStateCh: scConnectivityStateCh, + } + bd.Data = balancer.Get(pickfirstleaf.Name).Build(ccw, bd.BuildOptions) + }, + Close: func(bd *stub.BalancerData) { + bd.Data.(balancer.Balancer).Close() + }, + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState) + return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) + }, + } + + stub.Register(t.Name(), bf) + svcCfg := fmt.Sprintf(`{ "loadBalancingConfig": [{%q: {}}] }`, t.Name()) + backend := stubserver.StartTestService(t, nil) + defer backend.Stop() + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(svcCfg), + } + cc, err := grpc.NewClient(backend.Address, opts...) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed: %v", backend.Address, err) + + } + defer cc.Close() + cc.Connect() + + var healthListener func(balancer.SubConnState) + select { + case healthListener = <-healthListenerCh: + case <-ctx.Done(): + t.Fatal("Context timed out waiting for health listener to be registered.") + } + + // Wait for the raw connectivity state to become READY. The LB policy should + // wait for the health updates before transitioning the channel to READY. + for { + var scs balancer.SubConnState + select { + case scs = <-scConnectivityStateCh: + case <-ctx.Done(): + t.Fatal("Context timed out waiting for the SubConn connectivity state to become READY.") + } + if scs.ConnectivityState == connectivity.Ready { + break + } + } + + shortCtx, cancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer cancel() + testutils.AwaitNoStateChange(shortCtx, t, cc, connectivity.Connecting) + + // The LB policy should update the channel state based on the health state. + healthListener(balancer.SubConnState{ + ConnectivityState: connectivity.TransientFailure, + ConnectionError: fmt.Errorf("test health check failure"), + }) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) + + healthListener(balancer.SubConnState{ + ConnectivityState: connectivity.Connecting, + ConnectionError: balancer.ErrNoSubConnAvailable, + }) + testutils.AwaitState(ctx, t, cc, connectivity.Connecting) + + healthListener(balancer.SubConnState{ + ConnectivityState: connectivity.Ready, + }) + if err := pickfirst.CheckRPCsToBackend(ctx, cc, resolver.Address{Addr: backend.Address}); err != nil { + t.Fatal(err) + } + + // When the health check fails, the channel should transition to TF. + healthListener(balancer.SubConnState{ + ConnectivityState: connectivity.TransientFailure, + ConnectionError: fmt.Errorf("test health check failure"), + }) + testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) +} + +// healthListenerCapturingCCWrapper is used to capture the health listener so +// that health updates can be mocked for testing. +type healthListenerCapturingCCWrapper struct { + balancer.ClientConn + healthListenerCh chan func(balancer.SubConnState) + subConnStateCh chan balancer.SubConnState +} + +func (ccw *healthListenerCapturingCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + oldListener := opts.StateListener + opts.StateListener = func(scs balancer.SubConnState) { + ccw.subConnStateCh <- scs + if oldListener != nil { + oldListener(scs) + } + } + sc, err := ccw.ClientConn.NewSubConn(addrs, opts) + if err != nil { + return nil, err + } + return &healthListenerCapturingSCWrapper{ + SubConn: sc, + listenerCh: ccw.healthListenerCh, + }, nil +} + +func (ccw *healthListenerCapturingCCWrapper) UpdateState(state balancer.State) { + state.Picker = &unwrappingPicker{state.Picker} + ccw.ClientConn.UpdateState(state) +} + +type healthListenerCapturingSCWrapper struct { + balancer.SubConn + listenerCh chan func(balancer.SubConnState) +} + +func (scw *healthListenerCapturingSCWrapper) RegisterHealthListener(listener func(balancer.SubConnState)) { + scw.listenerCh <- listener +} + +// unwrappingPicker unwraps SubConns because the channel expects SubConns to be +// addrConns. +type unwrappingPicker struct { + balancer.Picker +} + +func (pw *unwrappingPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + pr, err := pw.Picker.Pick(info) + if pr.SubConn != nil { + pr.SubConn = pr.SubConn.(*healthListenerCapturingSCWrapper).SubConn + } + return pr, err +} + // subConnAddresses makes the pickfirst balancer create the requested number of // SubConns by triggering transient failures. The function returns the // addresses of the created SubConns.