diff --git a/v2/backend.go b/v2/backend.go
index 4e5d313f..cccb2c36 100644
--- a/v2/backend.go
+++ b/v2/backend.go
@@ -52,7 +52,8 @@ type Backend interface {
 	Store(ctx context.Context, key string, value any) (any, error)
 
 	// Fetch returns the record for the given path or a [ds.ErrNotFound] if it
-	// wasn't found or another error if any occurred.
+	// wasn't found or another error if any occurred. key won't contain the
+	// namespace prefix.
 	Fetch(ctx context.Context, key string) (any, error)
 }
 
diff --git a/v2/config_test.go b/v2/config_test.go
index 739216ab..ad84b8d4 100644
--- a/v2/config_test.go
+++ b/v2/config_test.go
@@ -78,6 +78,33 @@ func TestConfig_Validate(t *testing.T) {
 		assert.Error(t, cfg.Validate())
 	})
 
+	t.Run("backends for ipfs protocol (public key missing)", func(t *testing.T) {
+		cfg := DefaultConfig()
+		cfg.ProtocolID = ProtocolIPFS
+		cfg.Backends[namespaceProviders] = &RecordBackend{}
+		cfg.Backends[namespaceIPNS] = &RecordBackend{}
+		cfg.Backends["another"] = &RecordBackend{}
+		assert.Error(t, cfg.Validate())
+	})
+
+	t.Run("backends for ipfs protocol (ipns missing)", func(t *testing.T) {
+		cfg := DefaultConfig()
+		cfg.ProtocolID = ProtocolIPFS
+		cfg.Backends[namespaceProviders] = &RecordBackend{}
+		cfg.Backends["another"] = &RecordBackend{}
+		cfg.Backends[namespacePublicKey] = &RecordBackend{}
+		assert.Error(t, cfg.Validate())
+	})
+
+	t.Run("backends for ipfs protocol (providers missing)", func(t *testing.T) {
+		cfg := DefaultConfig()
+		cfg.ProtocolID = ProtocolIPFS
+		cfg.Backends["another"] = &RecordBackend{}
+		cfg.Backends[namespaceIPNS] = &RecordBackend{}
+		cfg.Backends[namespacePublicKey] = &RecordBackend{}
+		assert.Error(t, cfg.Validate())
+	})
+
 	t.Run("nil address filter", func(t *testing.T) {
 		cfg := DefaultConfig()
 		cfg.AddressFilter = nil
diff --git a/v2/handlers_test.go b/v2/handlers_test.go
index a94816e1..6910c347 100644
--- a/v2/handlers_test.go
+++ b/v2/handlers_test.go
@@ -34,6 +34,7 @@ var rng = rand.New(rand.NewSource(1337))
 
 func newTestDHT(t testing.TB) *DHT {
 	cfg := DefaultConfig()
+	cfg.Logger = devnull
 
 	return newTestDHTWithConfig(t, cfg)
 }
diff --git a/v2/internal/coord/coordinator.go b/v2/internal/coord/coordinator.go
index afe712d0..b167c840 100644
--- a/v2/internal/coord/coordinator.go
+++ b/v2/internal/coord/coordinator.go
@@ -385,7 +385,7 @@ func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coor
 	defer cancel()
 
 	if numResults < 1 {
-		numResults = 20
+		numResults = 20 // TODO: parameterize
 	}
 
 	seeds, err := c.GetClosestNodes(ctx, msg.Target(), numResults)
@@ -424,7 +424,7 @@ func (c *Coordinator) BroadcastRecord(ctx context.Context, msg *pb.Message) erro
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	seeds, err := c.GetClosestNodes(ctx, msg.Target(), 20)
+	seeds, err := c.GetClosestNodes(ctx, msg.Target(), 20) // TODO: parameterize
 	if err != nil {
 		return err
 	}
@@ -449,9 +449,7 @@ func (c *Coordinator) BroadcastRecord(ctx context.Context, msg *pb.Message) erro
 	// queue the start of the query
 	c.brdcstBehaviour.Notify(ctx, cmd)
 
-	contacted, errs, err := c.waitForBroadcast(ctx, waiter)
-	fmt.Println(contacted)
-	fmt.Println(errs)
+	_, _, err = c.waitForBroadcast(ctx, waiter)
 
 	return err
 }
diff --git a/v2/internal/coord/query/query.go b/v2/internal/coord/query/query.go
index 00168082..77a4dae8 100644
--- a/v2/internal/coord/query/query.go
+++ b/v2/internal/coord/query/query.go
@@ -84,7 +84,7 @@ type Query[K kad.Key[K], N kad.NodeID[K], M coordt.Message] struct {
 	findCloser bool
 	stats      QueryStats
 
-	// finished indicates that that the query has completed its work or has been stopped.
+	// finished indicates that the query has completed its work or has been stopped.
 	finished bool
 
 	// targetNodes is the set of responsive nodes thought to be closest to the target.
diff --git a/v2/routing.go b/v2/routing.go
index eec85c30..756a3d48 100644
--- a/v2/routing.go
+++ b/v2/routing.go
@@ -16,6 +16,7 @@ import (
 	"github.com/libp2p/go-libp2p/core/routing"
 	"go.opentelemetry.io/otel/attribute"
 	otel "go.opentelemetry.io/otel/trace"
+	"golang.org/x/exp/slog"
 
 	"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/coordt"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
@@ -110,20 +111,100 @@ func (d *DHT) Provide(ctx context.Context, c cid.Cid, brdcst bool) error {
 }
 
 func (d *DHT) FindProvidersAsync(ctx context.Context, c cid.Cid, count int) <-chan peer.AddrInfo {
-	_, span := d.tele.Tracer.Start(ctx, "DHT.FindProvidersAsync", otel.WithAttributes(attribute.String("cid", c.String()), attribute.Int("count", count)))
+	peerOut := make(chan peer.AddrInfo)
+	go d.findProvidersAsyncRoutine(ctx, c, count, peerOut)
+	return peerOut
+}
+
+func (d *DHT) findProvidersAsyncRoutine(ctx context.Context, c cid.Cid, count int, out chan peer.AddrInfo) {
+	_, span := d.tele.Tracer.Start(ctx, "DHT.findProvidersAsyncRoutine", otel.WithAttributes(attribute.String("cid", c.String()), attribute.Int("count", count)))
 	defer span.End()
 
-	// verify if this DHT supports provider records by checking if a "providers"
-	// backend is registered.
-	_, found := d.backends[namespaceProviders]
+	defer close(out)
+
+	// verify if this DHT supports provider records by checking
+	// if a "providers" backend is registered.
+	b, found := d.backends[namespaceProviders]
 	if !found || !c.Defined() {
-		peerOut := make(chan peer.AddrInfo)
-		close(peerOut)
-		return peerOut
+		span.RecordError(fmt.Errorf("no providers backend registered or CID undefined"))
+		return
 	}
 
-	// TODO reach out to Zikade
-	panic("implement me")
+	// first fetch the record locally
+	stored, err := b.Fetch(ctx, string(c.Hash()))
+	if err != nil {
+		span.RecordError(err)
+		d.log.Warn("Fetching value from provider store", slog.String("cid", c.String()), slog.String("err", err.Error()))
+		return
+	}
+
+	ps, ok := stored.(*providerSet)
+	if !ok {
+		span.RecordError(fmt.Errorf("expected *providerSet in provider store, got %T", stored))
+		d.log.Warn("Stored value is not a provider set", slog.String("cid", c.String()), slog.String("type", fmt.Sprintf("%T", stored)))
+		return
+	}
+
+	// send all providers onto the out channel until the desired count
+	// is reached. If no count was specified, continue with the network lookup.
+	providers := map[peer.ID]struct{}{}
+	for _, provider := range ps.providers {
+		providers[provider.ID] = struct{}{}
+
+		select {
+		case <-ctx.Done():
+			return
+		case out <- provider:
+		}
+
+		if count != 0 && len(providers) == count {
+			return
+		}
+	}
+
+	// Craft message to send to other peers
+	msg := &pb.Message{
+		Type: pb.Message_GET_PROVIDERS,
+		Key:  c.Hash(),
+	}
+
+	// handle node response
+	fn := func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats coordt.QueryStats) error {
+		// loop through all providers that the remote peer returned
+		for _, provider := range resp.ProviderAddrInfos() {
+
+			// if we have already sent that peer on the channel -> do nothing
+			if _, found := providers[provider.ID]; found {
+				continue
+			}
+
+			// keep track that we will have sent this peer on the channel
+			providers[provider.ID] = struct{}{}
+
+			// actually send the provider information to the user
+			select {
+			case <-ctx.Done():
+				return coordt.ErrSkipRemaining
+			case out <- provider:
+			}
+
+			// if count is 0, we will wait until the query has exhausted the keyspace.
+			// if count isn't 0, we will stop if the number of providers we have sent
+			// equals the number that the user has requested.
+			if count != 0 && len(providers) == count {
+				return coordt.ErrSkipRemaining
+			}
+		}
+
+		return nil
+	}
+
+	_, err = d.kad.QueryMessage(ctx, msg, fn, 20) // TODO: parameterize
+	if err != nil {
+		span.RecordError(err)
+		d.log.Warn("Failed querying", slog.String("cid", c.String()), slog.String("err", err.Error()))
+		return
+	}
 }
 
 // PutValue satisfies the [routing.Routing] interface and will add the given
diff --git a/v2/routing_test.go b/v2/routing_test.go
index 8647b56e..d8aa1e90 100644
--- a/v2/routing_test.go
+++ b/v2/routing_test.go
@@ -1,11 +1,18 @@
 package dht
 
 import (
+	"context"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
 	"testing"
 
+	"github.com/ipfs/go-cid"
+	"github.com/ipfs/go-datastore/failstore"
 	"github.com/libp2p/go-libp2p/core/crypto"
 	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/libp2p/go-libp2p/core/routing"
+	mh "github.com/multiformats/go-multihash"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
@@ -93,3 +100,274 @@ func TestGetValueOnePeer(t *testing.T) {
 
 	require.Equal(t, v, val)
 }
+
+// newRandomContent reads 1024 bytes from crypto/rand and builds a CID from their SHA2-256 digest.
+func newRandomContent(t testing.TB) cid.Cid {
+	raw := make([]byte, 1024)
+	_, err := rand.Read(raw)
+	require.NoError(t, err)
+
+	hash := sha256.New()
+	hash.Write(raw)
+
+	mhash, err := mh.Encode(hash.Sum(nil), mh.SHA2_256)
+	require.NoError(t, err)
+
+	return cid.NewCidV0(mhash)
+}
+
+func TestDHT_FindProvidersAsync_empty_routing_table(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+	c := newRandomContent(t)
+
+	out := d.FindProvidersAsync(ctx, c, 1)
+	select {
+	case _, more := <-out:
+		require.False(t, more)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
+
+func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+	fillRoutingTable(t, d, 250)
+
+	delete(d.backends, namespaceProviders)
+
+	out := d.FindProvidersAsync(ctx, newRandomContent(t), 1)
+	select {
+	case _, more := <-out:
+		require.False(t, more)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
+
+func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+	fillRoutingTable(t, d, 250)
+
+	c := newRandomContent(t)
+	provider := peer.AddrInfo{ID: newPeerID(t)}
+	_, err := d.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider)
+	require.NoError(t, err)
+
+	out := d.FindProvidersAsync(ctx, c, 1)
+	for {
+		select {
+		case p, more := <-out:
+			if !more {
+				return
+			}
+			assert.Equal(t, provider.ID, p.ID)
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		}
+	}
+}
+
+func TestDHT_FindProvidersAsync_returns_only_count_from_local_store(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+	fillRoutingTable(t, d, 250)
+
+	c := newRandomContent(t)
+
+	storedCount := 5
+	requestedCount := 3
+
+	// invariant for this test
+	assert.Less(t, requestedCount, storedCount)
+
+	for i := 0; i < storedCount; i++ {
+		provider := peer.AddrInfo{ID: newPeerID(t)}
+		_, err := d.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider)
+		require.NoError(t, err)
+	}
+
+	out := d.FindProvidersAsync(ctx, c, requestedCount)
+
+	returnedCount := 0
+LOOP:
+	for {
+		select {
+		case _, more := <-out:
+			if !more {
+				break LOOP
+			}
+			returnedCount += 1
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		}
+	}
+	assert.Equal(t, requestedCount, returnedCount)
+}
+
+func TestDHT_FindProvidersAsync_queries_other_peers(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+
+	c := newRandomContent(t)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+	d3 := top.AddServer(nil)
+
+	top.ConnectChain(ctx, d1, d2, d3)
+
+	provider := peer.AddrInfo{ID: newPeerID(t)}
+	_, err := d3.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider)
+	require.NoError(t, err)
+
+	out := d1.FindProvidersAsync(ctx, c, 1)
+	select {
+	case p, more := <-out:
+		require.True(t, more)
+		assert.Equal(t, provider.ID, p.ID)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+
+	select {
+	case _, more := <-out:
+		assert.False(t, more)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
+
+func TestDHT_FindProvidersAsync_respects_cancelled_context_for_local_query(t *testing.T) {
+	// Test strategy:
+	// We let d know about providersCount providers for the CID c.
+	// Then we ask it to find providers but pass it a cancelled context.
+	// We assert that we are sending on the channel while also respecting the
+	// cancelled context by checking that the number of returned providers is
+	// less than the number of providers d knows about. Since it's random
+	// which select case gets chosen, providersCount must be significantly
+	// large. This is a statistical test, and we should watch whether it
+	// turns out to be flaky.
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	c := newRandomContent(t)
+
+	providersCount := 50
+	for i := 0; i < providersCount; i++ {
+		provider := peer.AddrInfo{ID: newPeerID(t)}
+		_, err := d.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider)
+		require.NoError(t, err)
+	}
+
+	cancelledCtx, cancel := context.WithCancel(ctx)
+	cancel()
+
+	out := d.FindProvidersAsync(cancelledCtx, c, 0)
+
+	returnedCount := 0
+LOOP:
+	for {
+		select {
+		case _, more := <-out:
+			if !more {
+				break LOOP
+			}
+			returnedCount += 1
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		}
+	}
+	assert.Less(t, returnedCount, providersCount)
+}
+
+func TestDHT_FindProvidersAsync_does_not_return_same_record_twice(t *testing.T) {
+	// Test setup:
+	// There are two providers in the network for CID c.
+	// d1 has information about one provider locally.
+	// d2 has information about two providers, one of which d1 already knows about.
+	// We assert that the locally known provider is only returned once.
+	// The query should run until exhaustion.
+	ctx := kadtest.CtxShort(t)
+
+	c := newRandomContent(t)
+
+	top := NewTopology(t)
+	d1 := top.AddServer(nil)
+	d2 := top.AddServer(nil)
+
+	top.Connect(ctx, d1, d2)
+
+	provider1 := peer.AddrInfo{ID: newPeerID(t)}
+	provider2 := peer.AddrInfo{ID: newPeerID(t)}
+
+	// store provider1 with d1
+	_, err := d1.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider1)
+	require.NoError(t, err)
+
+	// store provider1 with d2
+	_, err = d2.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider1)
+	require.NoError(t, err)
+
+	// store provider2 with d2
+	_, err = d2.backends[namespaceProviders].Store(ctx, string(c.Hash()), provider2)
+	require.NoError(t, err)
+
+	out := d1.FindProvidersAsync(ctx, c, 0)
+	count := 0
+LOOP:
+	for {
+		select {
+		case p, more := <-out:
+			if !more {
+				break LOOP
+			}
+			count += 1
+			assert.True(t, p.ID == provider1.ID || p.ID == provider2.ID)
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		}
+	}
+	assert.Equal(t, 2, count)
+}
+
+func TestDHT_FindProvidersAsync_datastore_error(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	// construct a datastore that fails for any operation
+	memStore, err := InMemoryDatastore()
+	require.NoError(t, err)
+
+	dstore := failstore.NewFailstore(memStore, func(s string) error {
+		return fmt.Errorf("some error")
+	})
+
+	be, err := typedBackend[*ProvidersBackend](d, namespaceProviders)
+	require.NoError(t, err)
+
+	be.datastore = dstore
+
+	out := d.FindProvidersAsync(ctx, newRandomContent(t), 0)
+	select {
+	case _, more := <-out:
+		assert.False(t, more)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
+
+func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) {
+	ctx := kadtest.CtxShort(t)
+	d := newTestDHT(t)
+
+	out := d.FindProvidersAsync(ctx, cid.Cid{}, 0)
+	select {
+	case _, more := <-out:
+		assert.False(t, more)
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
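
A minimal consumption sketch for the FindProvidersAsync behavior introduced above. The host and DHT setup (libp2p.New, dht.New) is illustrative boilerplate and an assumption about the package's constructor, not part of this change; only the channel semantics reflect the diff: the returned channel is closed once count providers have been sent, the query has exhausted the keyspace, or the context is cancelled.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ipfs/go-cid"
	"github.com/libp2p/go-libp2p"
	mh "github.com/multiformats/go-multihash"

	dht "github.com/libp2p/go-libp2p-kad-dht/v2"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Illustrative setup: a default libp2p host and a v2 DHT. The New call is
	// an assumed constructor for this sketch and not part of the diff above.
	h, err := libp2p.New()
	if err != nil {
		panic(err)
	}
	d, err := dht.New(h, dht.DefaultConfig())
	if err != nil {
		panic(err)
	}

	// Derive a CID the same way the new tests do: from a SHA2-256 multihash.
	digest, err := mh.Sum([]byte("hello world"), mh.SHA2_256, -1)
	if err != nil {
		panic(err)
	}
	c := cid.NewCidV1(cid.Raw, digest)

	// Ask for up to 5 providers. The channel is closed when the requested
	// count is reached, the query exhausts the keyspace, or ctx is cancelled.
	for provider := range d.FindProvidersAsync(ctx, c, 5) {
		fmt.Println("found provider:", provider.ID)
	}
}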