diff --git a/internal/coord/behaviour.go b/internal/coord/behaviour.go index 7460994..21de6b0 100644 --- a/internal/coord/behaviour.go +++ b/internal/coord/behaviour.go @@ -146,3 +146,38 @@ func (w *Waiter[E]) Close() { func (w *Waiter[E]) Chan() <-chan WaiterEvent[E] { return w.pending } + +// NotifyCloserHook implements the [NotifyCloser] interface and provides hooks +// into the Notify and Close calls by wrapping another [NotifyCloser]. This is +// intended to be used in testing. +type NotifyCloserHook[E BehaviourEvent] struct { + nc NotifyCloser[E] + BeforeNotify func(context.Context, E) + AfterNotify func(context.Context, E) + BeforeClose func() + AfterClose func() +} + +var _ NotifyCloser[BehaviourEvent] = (*NotifyCloserHook[BehaviourEvent])(nil) + +func NewNotifyCloserHook[E BehaviourEvent](nc NotifyCloser[E]) *NotifyCloserHook[E] { + return &NotifyCloserHook[E]{ + nc: nc, + BeforeNotify: func(ctx context.Context, e E) {}, + AfterNotify: func(ctx context.Context, e E) {}, + BeforeClose: func() {}, + AfterClose: func() {}, + } +} + +func (n *NotifyCloserHook[E]) Notify(ctx context.Context, ev E) { + n.BeforeNotify(ctx, ev) + n.nc.Notify(ctx, ev) + n.AfterNotify(ctx, ev) +} + +func (n *NotifyCloserHook[E]) Close() { + n.BeforeClose() + n.nc.Close() + n.AfterClose() +} diff --git a/internal/coord/query_test.go b/internal/coord/query_test.go index 825b4bd..314826e 100644 --- a/internal/coord/query_test.go +++ b/internal/coord/query_test.go @@ -103,6 +103,7 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { // start query waiter := NewWaiter[BehaviourEvent]() + wrappedWaiter := NewNotifyCloserHook[BehaviourEvent](waiter) waiterDone := make(chan struct{}) waiterMsg := make(chan struct{}) @@ -121,7 +122,7 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { Target: msg.Target(), Message: msg, KnownClosestNodes: []kadt.PeerID{nodes[1].NodeID}, - Notify: waiter, + Notify: wrappedWaiter, NumResults: 0, }) @@ -144,6 +145,15 @@ func 
TestPooledQuery_deadlock_regression(t *testing.T) { ev, _ = c.queryBehaviour.Perform(ctx) require.IsType(t, &EventOutboundSendMessage{}, ev) + hasLock := make(chan struct{}) + var once sync.Once + wrappedWaiter.BeforeNotify = func(ctx context.Context, event BehaviourEvent) { + once.Do(func() { + require.IsType(t, &EventQueryProgressed{}, event) // verify test invariant + close(hasLock) + }) + } + // Simulate a successful response from the new node. This node didn't return // any new nodes to contact. This means the query pool behaviour will notify // the waiter about a query progression and afterward about a finished @@ -151,15 +161,17 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { // of 1, the channel cannot hold both events. At the same time, the waiter // doesn't consume the messages because it's busy processing the previous // query event (because we haven't released the blocking waiterMsg call above). - var wg sync.WaitGroup - wg.Add(1) - go func() { - wg.Done() - c.queryBehaviour.Notify(ctx, successMsg(nodes[2].NodeID)) - }() + go c.queryBehaviour.Notify(ctx, successMsg(nodes[2].NodeID)) + + // wait until the above Notify call was handled by waiting until the hasLock + // channel was closed in the above BeforeNotify hook. If that hook is called + // we can be sure that the above Notify call has acquired the pooled query + // behaviour's pendingMu lock. + kadtest.AssertClosed(t, ctx, hasLock) - wg.Wait() - <-waiterMsg + // Since we know that the pooled query behaviour holds the lock we can + // release the slow waiter by reading an item from the waiterMsg channel. + kadtest.ReadItem(t, ctx, waiterMsg) // At this point, the waitForQuery QueryFunc callback returned a // coordt.ErrSkipRemaining. This instructs the waitForQuery method to notify @@ -168,10 +180,5 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { // lock on the pending events to process. Therefore, this notify call will // also block. 
At the same time, the waiter cannot read the new messages // from the query behaviour because it tries to notify it. - - select { - case <-waiterDone: - case <-ctx.Done(): - t.Fatalf("tiemout") - } + kadtest.AssertClosed(t, ctx, waiterDone) } diff --git a/internal/kadtest/chan.go b/internal/kadtest/chan.go new file mode 100644 index 0000000..e4f030e --- /dev/null +++ b/internal/kadtest/chan.go @@ -0,0 +1,36 @@ +package kadtest + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func ReadItem[T any](t testing.TB, ctx context.Context, c <-chan T) T { + t.Helper() + + select { + case val, more := <-c: + require.True(t, more, "channel closed unexpectedly") + return val + case <-ctx.Done(): + t.Fatal("timeout reading item") + return *new(T) + } +} + +// AssertClosed triggers a test failure if the given channel was not closed but +// carried more values or a timeout occurs (given by the context). +func AssertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) { + t.Helper() + + select { + case _, more := <-c: + assert.False(t, more) + case <-ctx.Done(): + t.Fatal("timeout closing channel") + } +} diff --git a/routing_test.go b/routing_test.go index e6807e9..a1d894f 100644 --- a/routing_test.go +++ b/routing_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/plprobelab/zikade/internal/kadtest" + kadtest "github.com/plprobelab/zikade/internal/kadtest" "github.com/plprobelab/zikade/kadt" ) @@ -289,7 +289,7 @@ func TestDHT_FindProvidersAsync_empty_routing_table(t *testing.T) { c := newRandomContent(t) out := d.FindProvidersAsync(ctx, c, 1) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { @@ -300,7 +300,7 @@ func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { delete(d.backends, namespaceProviders) out := 
d.FindProvidersAsync(ctx, newRandomContent(t), 1) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { @@ -315,10 +315,10 @@ func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { out := d.FindProvidersAsync(ctx, c, 1) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, provider.ID, val.ID) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_returns_only_count_from_local_store(t *testing.T) { @@ -376,10 +376,10 @@ func TestDHT_FindProvidersAsync_queries_other_peers(t *testing.T) { out := d1.FindProvidersAsync(ctx, c, 1) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, provider.ID, val.ID) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_respects_cancelled_context_for_local_query(t *testing.T) { @@ -493,7 +493,7 @@ func TestDHT_FindProvidersAsync_datastore_error(t *testing.T) { be.datastore = dstore out := d.FindProvidersAsync(ctx, newRandomContent(t), 0) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { @@ -501,7 +501,7 @@ func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { d := newTestDHT(t) out := d.FindProvidersAsync(ctx, cid.Cid{}, 0) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_GetValue_happy_path(t *testing.T) { @@ -563,32 +563,6 @@ func TestDHT_GetValue_returns_not_found_error(t *testing.T) { assert.Nil(t, valueChan) } -// assertClosed triggers a test failure if the given channel was not closed but -// carried more values or a timeout occurs (given by the context). 
-func assertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) { - t.Helper() - - select { - case _, more := <-c: - assert.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout closing channel") - } -} - -func readItem[T any](t testing.TB, ctx context.Context, c <-chan T) T { - t.Helper() - - select { - case val, more := <-c: - require.True(t, more, "channel closed unexpectedly") - return val - case <-ctx.Done(): - t.Fatal("timeout reading item") - return *new(T) - } -} - func TestDHT_SearchValue_simple(t *testing.T) { // Test setup: // There is just one other server that returns a valid value. @@ -608,10 +582,10 @@ func TestDHT_SearchValue_simple(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, v, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_returns_best_values(t *testing.T) { @@ -657,13 +631,13 @@ func TestDHT_SearchValue_returns_best_values(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, validValue, val) - val = readItem(t, ctx, valChan) + val = kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, betterValue, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } // In order for 'go test' to run this suite, we need to create @@ -756,10 +730,10 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumReachedPrematurely() { out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(3)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter() { @@ -768,13 +742,13 @@ func (suite *SearchValueQuorumTestSuite) 
TestQuorumReachedAfterDiscoveryOfBetter out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(5)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumZero() { @@ -785,13 +759,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumZero() { out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(0)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() { @@ -802,13 +776,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() { out, err := suite.d.SearchValue(ctx, suite.key) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_SearchValue_routing_option_returns_error(t *testing.T) { @@ -864,7 +838,7 @@ func TestDHT_SearchValue_stops_with_cancelled_context(t *testing.T) { valueChan, err := d1.SearchValue(cancelledCtx, "/"+namespaceIPNS+"/some-key") assert.NoError(t, err) - assertClosed(t, ctx, valueChan) + kadtest.AssertClosed(t, ctx, valueChan) } func TestDHT_SearchValue_has_record_locally(t *testing.T) { @@ -892,13 +866,13 @@ func TestDHT_SearchValue_has_record_locally(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := 
readItem(t, ctx, valChan) // from local store + val := kadtest.ReadItem(t, ctx, valChan) // from local store assert.Equal(t, validValue, val) - val = readItem(t, ctx, valChan) + val = kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, betterValue, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_offline(t *testing.T) { @@ -914,10 +888,10 @@ func TestDHT_SearchValue_offline(t *testing.T) { valChan, err := d.SearchValue(ctx, key, routing.Offline) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, v, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_offline_not_found_locally(t *testing.T) {