Skip to content

Commit

Permalink
Make regression test deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
dennis-tra committed Oct 10, 2023
1 parent 628f799 commit de889e4
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 72 deletions.
35 changes: 35 additions & 0 deletions internal/coord/behaviour.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,38 @@ func (w *Waiter[E]) Close() {
func (w *Waiter[E]) Chan() <-chan WaiterEvent[E] {
return w.pending
}

// NotifyCloserHook implements the [NotifyCloser] interface and provides hooks
// into the Notify and Close calls by wrapping another [NotifyCloser]. This is
// intended to be used in testing.
type NotifyCloserHook[E BehaviourEvent] struct {
nc NotifyCloser[E]
BeforeNotify func(context.Context, E)
AfterNotify func(context.Context, E)
BeforeClose func()
AfterClose func()
}

var _ NotifyCloser[BehaviourEvent] = (*NotifyCloserHook[BehaviourEvent])(nil)

func NewNotifyCloserHook[E BehaviourEvent](nc NotifyCloser[E]) *NotifyCloserHook[E] {
return &NotifyCloserHook[E]{
nc: nc,
BeforeNotify: func(ctx context.Context, e E) {},
AfterNotify: func(ctx context.Context, e E) {},
BeforeClose: func() {},
AfterClose: func() {},
}
}

func (n *NotifyCloserHook[E]) Notify(ctx context.Context, ev E) {
n.BeforeNotify(ctx, ev)
n.nc.Notify(ctx, ev)
n.AfterNotify(ctx, ev)
}

func (n *NotifyCloserHook[E]) Close() {
n.BeforeClose()
n.nc.Close()
n.AfterClose()
}
37 changes: 22 additions & 15 deletions internal/coord/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func TestPooledQuery_deadlock_regression(t *testing.T) {

// start query
waiter := NewWaiter[BehaviourEvent]()
wrappedWaiter := NewNotifyCloserHook[BehaviourEvent](waiter)

waiterDone := make(chan struct{})
waiterMsg := make(chan struct{})
Expand All @@ -121,7 +122,7 @@ func TestPooledQuery_deadlock_regression(t *testing.T) {
Target: msg.Target(),
Message: msg,
KnownClosestNodes: []kadt.PeerID{nodes[1].NodeID},
Notify: waiter,
Notify: wrappedWaiter,
NumResults: 0,
})

Expand All @@ -144,22 +145,33 @@ func TestPooledQuery_deadlock_regression(t *testing.T) {
ev, _ = c.queryBehaviour.Perform(ctx)
require.IsType(t, &EventOutboundSendMessage{}, ev)

hasLock := make(chan struct{})
var once sync.Once
wrappedWaiter.BeforeNotify = func(ctx context.Context, event BehaviourEvent) {
once.Do(func() {
require.IsType(t, &EventQueryProgressed{}, event) // verify test invariant
close(hasLock)
})
}

// Simulate a successful response from the new node. This node didn't return
// any new nodes to contact. This means the query pool behaviour will notify
// the waiter about a query progression and afterward about a finished
// query. Because (at the time of writing) the waiter has a channel buffer
// of 1, the channel cannot hold both events. At the same time, the waiter
// doesn't consume the messages because it's busy processing the previous
// query event (because we haven't released the blocking waiterMsg call above).
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
c.queryBehaviour.Notify(ctx, successMsg(nodes[2].NodeID))
}()
go c.queryBehaviour.Notify(ctx, successMsg(nodes[2].NodeID))

// wait until the above Notify call was handled by waiting until the hasLock
// channel was closed in the above BeforeNotify hook. If that hook is called
// we can be sure that the above Notify call has acquired the polled query
// behaviour's pendingMu lock.
kadtest.AssertClosed(t, ctx, hasLock)

wg.Wait()
<-waiterMsg
// Since we know that the pooled query behaviour holds the lock we can
// release the slow waiter by reading an item from the waiterMsg channel.
kadtest.ReadItem(t, ctx, waiterMsg)

// At this point, the waitForQuery QueryFunc callback returned a
// coordt.ErrSkipRemaining. This instructs the waitForQuery method to notify
Expand All @@ -168,10 +180,5 @@ func TestPooledQuery_deadlock_regression(t *testing.T) {
// lock on the pending events to process. Therefore, this notify call will
// also block. At the same time, the waiter cannot read the new messages
// from the query behaviour because it tries to notify it.

select {
case <-waiterDone:
case <-ctx.Done():
t.Fatalf("tiemout")
}
kadtest.AssertClosed(t, ctx, waiterDone)
}
36 changes: 36 additions & 0 deletions internal/kadtest/chan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package kadtest

import (
"context"
"testing"

"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
)

func ReadItem[T any](t testing.TB, ctx context.Context, c <-chan T) T {
t.Helper()

select {
case val, more := <-c:
require.True(t, more, "channel closed unexpectedly")
return val
case <-ctx.Done():
t.Fatal("timeout reading item")
return *new(T)
}
}

// AssertClosed triggers a test failure if the given channel was not closed but
// carried more values or a timeout occurs (given by the context).
func AssertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) {
t.Helper()

select {
case _, more := <-c:
assert.False(t, more)
case <-ctx.Done():
t.Fatal("timeout closing channel")
}
}
88 changes: 31 additions & 57 deletions routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"github.com/plprobelab/zikade/internal/kadtest"
kadtest "github.com/plprobelab/zikade/internal/kadtest"
"github.com/plprobelab/zikade/kadt"
)

Expand Down Expand Up @@ -289,7 +289,7 @@ func TestDHT_FindProvidersAsync_empty_routing_table(t *testing.T) {
c := newRandomContent(t)

out := d.FindProvidersAsync(ctx, c, 1)
assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) {
Expand All @@ -300,7 +300,7 @@ func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) {
delete(d.backends, namespaceProviders)

out := d.FindProvidersAsync(ctx, newRandomContent(t), 1)
assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) {
Expand All @@ -315,10 +315,10 @@ func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) {

out := d.FindProvidersAsync(ctx, c, 1)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, provider.ID, val.ID)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_FindProvidersAsync_returns_only_count_from_local_store(t *testing.T) {
Expand Down Expand Up @@ -376,10 +376,10 @@ func TestDHT_FindProvidersAsync_queries_other_peers(t *testing.T) {

out := d1.FindProvidersAsync(ctx, c, 1)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, provider.ID, val.ID)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_FindProvidersAsync_respects_cancelled_context_for_local_query(t *testing.T) {
Expand Down Expand Up @@ -493,15 +493,15 @@ func TestDHT_FindProvidersAsync_datastore_error(t *testing.T) {
be.datastore = dstore

out := d.FindProvidersAsync(ctx, newRandomContent(t), 0)
assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) {
ctx := kadtest.CtxShort(t)
d := newTestDHT(t)

out := d.FindProvidersAsync(ctx, cid.Cid{}, 0)
assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_GetValue_happy_path(t *testing.T) {
Expand Down Expand Up @@ -563,32 +563,6 @@ func TestDHT_GetValue_returns_not_found_error(t *testing.T) {
assert.Nil(t, valueChan)
}

// assertClosed triggers a test failure if the given channel was not closed but
// carried more values or a timeout occurs (given by the context).
func assertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) {
t.Helper()

select {
case _, more := <-c:
assert.False(t, more)
case <-ctx.Done():
t.Fatal("timeout closing channel")
}
}

func readItem[T any](t testing.TB, ctx context.Context, c <-chan T) T {
t.Helper()

select {
case val, more := <-c:
require.True(t, more, "channel closed unexpectedly")
return val
case <-ctx.Done():
t.Fatal("timeout reading item")
return *new(T)
}
}

func TestDHT_SearchValue_simple(t *testing.T) {
// Test setup:
// There is just one other server that returns a valid value.
Expand All @@ -608,10 +582,10 @@ func TestDHT_SearchValue_simple(t *testing.T) {
valChan, err := d1.SearchValue(ctx, key)
require.NoError(t, err)

val := readItem(t, ctx, valChan)
val := kadtest.ReadItem(t, ctx, valChan)
assert.Equal(t, v, val)

assertClosed(t, ctx, valChan)
kadtest.AssertClosed(t, ctx, valChan)
}

func TestDHT_SearchValue_returns_best_values(t *testing.T) {
Expand Down Expand Up @@ -657,13 +631,13 @@ func TestDHT_SearchValue_returns_best_values(t *testing.T) {
valChan, err := d1.SearchValue(ctx, key)
require.NoError(t, err)

val := readItem(t, ctx, valChan)
val := kadtest.ReadItem(t, ctx, valChan)
assert.Equal(t, validValue, val)

val = readItem(t, ctx, valChan)
val = kadtest.ReadItem(t, ctx, valChan)
assert.Equal(t, betterValue, val)

assertClosed(t, ctx, valChan)
kadtest.AssertClosed(t, ctx, valChan)
}

// In order for 'go test' to run this suite, we need to create
Expand Down Expand Up @@ -756,10 +730,10 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumReachedPrematurely() {
out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(3))
require.NoError(t, err)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.validValue, val)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter() {
Expand All @@ -768,13 +742,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter
out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(5))
require.NoError(t, err)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.validValue, val)

val = readItem(t, ctx, out)
val = kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.betterValue, val)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func (suite *SearchValueQuorumTestSuite) TestQuorumZero() {
Expand All @@ -785,13 +759,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumZero() {
out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(0))
require.NoError(t, err)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.validValue, val)

val = readItem(t, ctx, out)
val = kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.betterValue, val)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() {
Expand All @@ -802,13 +776,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() {
out, err := suite.d.SearchValue(ctx, suite.key)
require.NoError(t, err)

val := readItem(t, ctx, out)
val := kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.validValue, val)

val = readItem(t, ctx, out)
val = kadtest.ReadItem(t, ctx, out)
assert.Equal(t, suite.betterValue, val)

assertClosed(t, ctx, out)
kadtest.AssertClosed(t, ctx, out)
}

func TestDHT_SearchValue_routing_option_returns_error(t *testing.T) {
Expand Down Expand Up @@ -864,7 +838,7 @@ func TestDHT_SearchValue_stops_with_cancelled_context(t *testing.T) {

valueChan, err := d1.SearchValue(cancelledCtx, "/"+namespaceIPNS+"/some-key")
assert.NoError(t, err)
assertClosed(t, ctx, valueChan)
kadtest.AssertClosed(t, ctx, valueChan)
}

func TestDHT_SearchValue_has_record_locally(t *testing.T) {
Expand Down Expand Up @@ -892,13 +866,13 @@ func TestDHT_SearchValue_has_record_locally(t *testing.T) {
valChan, err := d1.SearchValue(ctx, key)
require.NoError(t, err)

val := readItem(t, ctx, valChan) // from local store
val := kadtest.ReadItem(t, ctx, valChan) // from local store
assert.Equal(t, validValue, val)

val = readItem(t, ctx, valChan)
val = kadtest.ReadItem(t, ctx, valChan)
assert.Equal(t, betterValue, val)

assertClosed(t, ctx, valChan)
kadtest.AssertClosed(t, ctx, valChan)
}

func TestDHT_SearchValue_offline(t *testing.T) {
Expand All @@ -914,10 +888,10 @@ func TestDHT_SearchValue_offline(t *testing.T) {
valChan, err := d.SearchValue(ctx, key, routing.Offline)
require.NoError(t, err)

val := readItem(t, ctx, valChan)
val := kadtest.ReadItem(t, ctx, valChan)
assert.Equal(t, v, val)

assertClosed(t, ctx, valChan)
kadtest.AssertClosed(t, ctx, valChan)
}

func TestDHT_SearchValue_offline_not_found_locally(t *testing.T) {
Expand Down

0 comments on commit de889e4

Please sign in to comment.