Skip to content

Commit

Permalink
pkg, internal: add pagination to list requests (#2016)
Browse files Browse the repository at this point in the history
* pkg, internal: add pagination to list requests

* address linter

* rename All to ListAll

* rename file

* beautify signature

* fix unit tests
  • Loading branch information
s-urbaniak authored Jan 3, 2025
1 parent 97cdffd commit 66fafc9
Show file tree
Hide file tree
Showing 19 changed files with 318 additions and 232 deletions.
1 change: 1 addition & 0 deletions internal/translation/customroles/custom_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (s *CustomRoles) Get(ctx context.Context, projectID string, roleName string
}

func (s *CustomRoles) List(ctx context.Context, projectID string) ([]CustomRole, error) {
// custom database roles does not offer paginated resources.
atlasRoles, _, err := s.roleAPI.ListCustomDatabaseRoles(ctx, projectID).Execute()
if err != nil {
return nil, fmt.Errorf("failed to list custom roles from Atlas: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package datafederation
import (
"context"
"fmt"
"net/http"

"go.mongodb.org/atlas-sdk/v20231115008/admin"
"go.uber.org/zap"
"k8s.io/apimachinery/pkg/types"

"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation"
"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/paging"
"github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/controller/atlas"
)

Expand All @@ -31,12 +33,14 @@ func NewDatafederationPrivateEndpointService(ctx context.Context, provider atlas
}

func (d *DatafederationPrivateEndpoints) List(ctx context.Context, projectID string) ([]*DatafederationPrivateEndpointEntry, error) {
paginatedResponse, _, err := d.api.ListDataFederationPrivateEndpoints(ctx, projectID).Execute()
results, err := paging.ListAll(ctx, func(ctx context.Context, pageNum int) (paging.Response[admin.PrivateNetworkEndpointIdEntry], *http.Response, error) {
return d.api.ListDataFederationPrivateEndpoints(ctx, projectID).PageNum(pageNum).Execute()
})
if err != nil {
return nil, fmt.Errorf("failed to list data federation private endpoints from Atlas: %w", err)
}

return endpointsFromAtlas(paginatedResponse.GetResults(), projectID)
return endpointsFromAtlas(results, projectID)
}

func (d *DatafederationPrivateEndpoints) Create(ctx context.Context, aep *DatafederationPrivateEndpointEntry) error {
Expand Down
9 changes: 7 additions & 2 deletions internal/translation/ipaccesslist/ipaccesslist.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package ipaccesslist
import (
"context"
"fmt"
"net/http"

"go.mongodb.org/atlas-sdk/v20231115008/admin"

"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/paging"
)

type IPAccessListService interface {
Expand All @@ -19,12 +22,14 @@ type IPAccessList struct {
}

func (i *IPAccessList) List(ctx context.Context, projectID string) (IPAccessEntries, error) {
netPermResult, _, err := i.ipAccessListAPI.ListProjectIpAccessLists(ctx, projectID).Execute()
netPermResult, err := paging.ListAll(ctx, func(ctx context.Context, pageNum int) (paging.Response[admin.NetworkPermissionEntry], *http.Response, error) {
return i.ipAccessListAPI.ListProjectIpAccessLists(ctx, projectID).PageNum(pageNum).Execute()
})
if err != nil {
return nil, fmt.Errorf("failed to get ip access list from Atlas: %w", err)
}

return fromAtlas(netPermResult.GetResults()), nil
return fromAtlas(netPermResult), nil
}

func (i *IPAccessList) Add(ctx context.Context, projectID string, entries IPAccessEntries) error {
Expand Down
32 changes: 12 additions & 20 deletions internal/translation/networkpeering/networkpeering.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package networkpeering
import (
"context"
"fmt"
"net/http"

"go.mongodb.org/atlas-sdk/v20231115008/admin"

"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/pointer"
"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/paging"
"github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/api/v1/provider"
)

Expand Down Expand Up @@ -66,28 +68,18 @@ func (np *networkPeeringService) ListPeers(ctx context.Context, projectID string
}

func (np *networkPeeringService) listPeersForProvider(ctx context.Context, projectID string, providerName provider.ProviderName) ([]NetworkPeer, error) {
results := []NetworkPeer{}
pageNum := 1
listOpts := &admin.ListPeeringConnectionsApiParams{
GroupId: projectID,
ProviderName: admin.PtrString(string(providerName)),
PageNum: pointer.MakePtr(pageNum),
}
for {
page, _, err := np.peeringAPI.ListPeeringConnectionsWithParams(ctx, listOpts).Execute()
if err != nil {
return nil, fmt.Errorf("failed to list network peers: %w", err)
}
list, err := fromAtlasConnectionList(page.GetResults())
if err != nil {
return nil, fmt.Errorf("failed to convert results to peer list: %w", err)
}
results = append(results, list...)
if len(results) >= page.GetTotalCount() {
return results, nil
results, err := paging.ListAll(ctx, func(ctx context.Context, pageNum int) (paging.Response[admin.BaseNetworkPeeringConnectionSettings], *http.Response, error) {
p := &admin.ListPeeringConnectionsApiParams{
GroupId: projectID,
ProviderName: admin.PtrString(string(providerName)),
}
pageNum += 1
return np.peeringAPI.ListPeeringConnectionsWithParams(ctx, p).PageNum(pageNum).Execute()
})
if err != nil {
return nil, fmt.Errorf("failed to list network peers: %w", err)
}

return fromAtlasConnectionList(results)
}

func (np *networkPeeringService) DeletePeer(ctx context.Context, projectID, peerID string) error {
Expand Down
38 changes: 38 additions & 0 deletions internal/translation/paging/list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package paging

import (
"context"
"errors"
"net/http"
)

// Response is the paginated response containing the current page results and the total count.
// It is implemented by all supported SDK versions.
type Response[T any] interface {
GetResults() []T
GetTotalCount() int
}

// ListAll invokes the given pagination list function multiple times until the total count of responses is gathered.
// Once done, all paginated responses are returned.
// If an error occurs, the first error occurrence will be returned.
//
// This is taken over from https://github.com/mongodb/terraform-provider-mongodbatlas/blob/a5581ebb274dbcaffd43d330c5bfbbb329cae51d/internal/common/dsschema/page_request.go#L14-L31.
func ListAll[T any](ctx context.Context, listFunc func(ctx context.Context, pageNum int) (Response[T], *http.Response, error)) ([]T, error) {
var results []T
for currentPage := 1; ; currentPage++ {
resp, _, err := listFunc(ctx, currentPage)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("no response")
}
currentResults := resp.GetResults()
results = append(results, currentResults...)
if len(currentResults) == 0 || len(results) >= resp.GetTotalCount() {
break
}
}
return results, nil
}
146 changes: 146 additions & 0 deletions internal/translation/paging/list_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package paging

import (
"context"
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

type page struct {
results []string
totalCount int
}

func (r *page) GetResults() []string {
if r == nil {
return nil
}
return r.results
}

func (r *page) GetTotalCount() int {
if r == nil {
return 0
}
return r.totalCount
}

func responder(pages []*page) func(ctx context.Context, pageNum int) (Response[string], *http.Response, error) {
totalCount := 0
for _, p := range pages {
if p == nil {
continue
}
totalCount = totalCount + len(p.results)
}

for _, p := range pages {
if p == nil {
continue
}
p.totalCount = totalCount
}

return func(ctx context.Context, pageNum int) (Response[string], *http.Response, error) {
if len(pages) == 0 {
return nil, nil, nil
}
return pages[pageNum-1], nil, nil
}
}

func TestAll(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
pages []*page
wantErr string
wantResult []string
}{
{
name: "no response",
wantErr: "no response",
},
{
name: "empty response",
pages: []*page{},
wantErr: "no response",
},
{
name: "empty results",
pages: []*page{
{results: []string{}},
},
wantResult: nil,
},
{
name: "single result",
pages: []*page{
{results: []string{"a"}},
},
wantResult: []string{"a"},
},
{
name: "multiple results",
pages: []*page{
{results: []string{"a", "b"}},
},
wantResult: []string{"a", "b"},
},
{
name: "one additional nil page",
pages: []*page{
{results: []string{"a", "b"}},
nil,
},
wantResult: []string{"a", "b"},
},
{
name: "one additional empty results page",
pages: []*page{
{results: []string{"a", "b"}},
{results: []string{}},
},
wantResult: []string{"a", "b"},
},
{
name: "multiple results",
pages: []*page{
{results: []string{"a", "b"}},
{results: []string{"c", "d"}},
},
wantResult: []string{"a", "b", "c", "d"},
},
{
name: "multiple results with nil page",
pages: []*page{
{results: []string{"a", "b"}},
nil,
{results: []string{"c", "d"}},
},
wantResult: []string{"a", "b"},
},
{
name: "multiple results with empty results",
pages: []*page{
{results: []string{"a", "b"}},
{results: []string{}},
{results: []string{"c", "d"}},
},
wantResult: []string{"a", "b"},
},
} {
t.Run(tc.name, func(t *testing.T) {
response, err := ListAll(ctx, responder(tc.pages))
gotErr := ""
if err != nil {
gotErr = err.Error()
}
require.Equal(t, tc.wantErr, gotErr)
require.Equal(t, tc.wantResult, response)
})
}
}
7 changes: 5 additions & 2 deletions internal/translation/teams/teams.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"go.mongodb.org/atlas-sdk/v20231115008/admin"

"github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/paging"
akov2 "github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/api/v1"
)

Expand Down Expand Up @@ -49,11 +50,13 @@ func NewTeamsAPIService(teamAPI admin.TeamsApi, userAPI admin.MongoDBCloudUsersA
}

func (tm *TeamsAPI) ListProjectTeams(ctx context.Context, projectID string) ([]AssignedTeam, error) {
atlasAssignedTeams, _, err := tm.teamsAPI.ListProjectTeams(ctx, projectID).Execute()
atlasAssignedTeams, err := paging.ListAll(ctx, func(ctx context.Context, pageNum int) (paging.Response[admin.TeamRole], *http.Response, error) {
return tm.teamsAPI.ListProjectTeams(ctx, projectID).PageNum(pageNum).Execute()
})
if err != nil {
return nil, fmt.Errorf("failed to get project team list from Atlas: %w", err)
}
return TeamRolesFromAtlas(atlasAssignedTeams.GetResults()), err
return TeamRolesFromAtlas(atlasAssignedTeams), err
}

func (tm *TeamsAPI) GetTeamByName(ctx context.Context, orgID, teamName string) (*AssignedTeam, error) {
Expand Down
Loading

0 comments on commit 66fafc9

Please sign in to comment.