Skip to content

Commit

Permalink
remove AddNodeType
Browse files Browse the repository at this point in the history
  • Loading branch information
ccbrown committed Oct 6, 2023
1 parent e39541b commit cd7bcd9
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 293 deletions.
51 changes: 0 additions & 51 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,54 +268,3 @@ func isNil(v interface{}) bool {
rv := reflect.ValueOf(v)
return (rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface) && rv.IsNil()
}

func (api *API) resolveNodeByGlobalId(ctx context.Context, id string) (interface{}, error) {
typeId, modelId := api.config.DeserializeNodeId(id)
nodeType, ok := api.config.nodeTypesById[typeId]
if !ok {
return nil, nil
}
return api.resolveNodeById(ctx, nodeType, modelId)
}

func (api *API) resolveNodesByGlobalIds(ctx context.Context, ids []string) ([]interface{}, error) {
modelIds := map[int][]interface{}{}
for _, id := range ids {
typeId, modelId := api.config.DeserializeNodeId(id)
modelIds[typeId] = append(modelIds[typeId], modelId)
}
var ret []interface{}
for typeId, modelIds := range modelIds {
nodeType, ok := api.config.nodeTypesById[typeId]
if !ok {
continue
}
ids := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(modelIds[0])), len(modelIds), len(modelIds))
for i, modelId := range modelIds {
ids.Index(i).Set(reflect.ValueOf(modelId))
}
nodes, err := nodeType.GetByIds(ctx, ids.Interface())
if !isNil(err) {
return nil, err
}
nodesValue := reflect.ValueOf(nodes)
for i := 0; i < nodesValue.Len(); i++ {
ret = append(ret, nodesValue.Index(i).Interface())
}
}
return ret, nil
}

func (api *API) resolveNodeById(ctx context.Context, nodeType *NodeType, modelId interface{}) (interface{}, error) {
ids := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(modelId)), 1, 1)
ids.Index(0).Set(reflect.ValueOf(modelId))
nodes, err := nodeType.GetByIds(ctx, ids.Interface())
if !isNil(err) {
return nil, err
}
nodesValue := reflect.ValueOf(nodes)
if nodesValue.Len() < 1 {
return nil, nil
}
return nodesValue.Index(0).Interface(), nil
}
44 changes: 20 additions & 24 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -108,39 +107,36 @@ func TestBatch(t *testing.T) {
}

func TestNodes(t *testing.T) {
const nodeTypeId = 10

testCfg := Config{
SerializeNodeId: func(typeId int, id interface{}) string {
assert.Equal(t, nodeTypeId, typeId)
return id.(string)
},
DeserializeNodeId: func(id string) (int, interface{}) {
return nodeTypeId, id
},
}

type node struct {
Id string
}

testCfg.AddNodeType(&NodeType{
Id: nodeTypeId,
Name: "TestNode",
Model: reflect.TypeOf(node{}),
GetByIds: func(ctx context.Context, ids interface{}) (interface{}, error) {
var ret []*node
for _, id := range ids.([]string) {
testCfg := Config{
ResolveNodesByGlobalIds: func(ctx context.Context, ids []string) ([]interface{}, error) {
var ret []interface{}
for _, id := range ids {
if id == "a" || id == "b" {
ret = append(ret, &node{
Id: id,
})
ret = append(ret, &node{Id: id})
}
}
return ret, nil
},
}

testCfg.AddNamedType(&graphql.ObjectType{
Name: "TestNode",
Fields: map[string]*graphql.FieldDefinition{
"id": OwnID("Id"),
"id": {
Type: graphql.NewNonNullType(graphql.IDType),
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
return ctx.Object.(*node).Id, nil
},
},
},
ImplementedInterfaces: []*graphql.InterfaceType{testCfg.NodeInterface()},
IsTypeOf: func(value interface{}) bool {
_, ok := value.(*node)
return ok
},
})

Expand Down
75 changes: 22 additions & 53 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"net/http"
"reflect"
"sync"

"github.com/sirupsen/logrus"
Expand All @@ -16,10 +15,13 @@ import (
type Config struct {
Logger logrus.FieldLogger
WebSocketOriginCheck func(r *http.Request) bool
SerializeNodeId func(typeId int, id interface{}) string
DeserializeNodeId func(string) (typeId int, id interface{})

// If given, these fields will be added to the Node interface.
AdditionalNodeFields map[string]*graphql.FieldDefinition

// Invoked to get nodes by their global ids.
ResolveNodesByGlobalIds func(ctx context.Context, ids []string) ([]interface{}, error)

// If given, Apollo persisted queries are supported by the API:
// https://www.apollographql.com/docs/react/api/link/persisted-queries/
PersistedQueryStorage PersistedQueryStorage
Expand Down Expand Up @@ -49,23 +51,15 @@ type Config struct {
// documentation.
PreprocessGraphQLSchemaDefinition func(schema *graphql.SchemaDefinition) error

initOnce sync.Once
nodeObjectTypesByName map[string]*graphql.ObjectType
nodeTypesByModel map[reflect.Type]*NodeType
nodeTypesById map[int]*NodeType
nodeTypesByObjectType map[*graphql.ObjectType]*NodeType
nodeInterface *graphql.InterfaceType
query *graphql.ObjectType
mutation *graphql.ObjectType
subscription *graphql.ObjectType
initOnce sync.Once
nodeInterface *graphql.InterfaceType
query *graphql.ObjectType
mutation *graphql.ObjectType
subscription *graphql.ObjectType
}

func (cfg *Config) init() {
cfg.initOnce.Do(func() {
cfg.nodeObjectTypesByName = make(map[string]*graphql.ObjectType)
cfg.nodeTypesByModel = make(map[reflect.Type]*NodeType)
cfg.nodeTypesById = make(map[int]*NodeType)
cfg.nodeTypesByObjectType = make(map[*graphql.ObjectType]*NodeType)
if cfg.AdditionalTypes == nil {
cfg.AdditionalTypes = make(map[string]graphql.NamedType)
}
Expand Down Expand Up @@ -95,7 +89,11 @@ func (cfg *Config) init() {
Cost: graphql.FieldResolverCost(1),
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
// TODO: batching?
return ctxAPI(ctx.Context).resolveNodeByGlobalId(ctx.Context, ctx.Arguments["id"].(string))
nodes, err := ctxAPI(ctx.Context).config.ResolveNodesByGlobalIds(ctx.Context, []string{ctx.Arguments["id"].(string)})
if err != nil || len(nodes) == 0 {
return nil, err
}
return nodes[0], nil
},
},
"nodes": {
Expand All @@ -118,7 +116,7 @@ func (cfg *Config) init() {
for _, id := range ctx.Arguments["ids"].([]interface{}) {
ids = append(ids, id.(string))
}
return ctxAPI(ctx.Context).resolveNodesByGlobalIds(ctx.Context, ids)
return ctxAPI(ctx.Context).config.ResolveNodesByGlobalIds(ctx.Context, ids)
},
},
},
Expand Down Expand Up @@ -158,48 +156,19 @@ func (cfg *Config) graphqlSchema() (*graphql.Schema, error) {
return graphql.NewSchema(def)
}

// NodeObjectType returns the object type for a node type previously added via AddNodeType.
func (cfg *Config) NodeObjectType(name string) *graphql.ObjectType {
return cfg.nodeObjectTypesByName[name]
}

// AddNodeType registers the given node type and returned the object type created for the node.
func (cfg *Config) AddNodeType(t *NodeType) *graphql.ObjectType {
cfg.init()

model := normalizeModelType(t.Model)
if _, ok := cfg.nodeTypesByModel[model]; ok {
panic("node type already exists for model")
}
cfg.nodeTypesByModel[model] = t

if _, ok := cfg.nodeTypesById[t.Id]; ok {
panic("node type already exists for type id")
}
cfg.nodeTypesById[t.Id] = t

objectType := &graphql.ObjectType{
Name: t.Name,
Fields: t.Fields,
ImplementedInterfaces: []*graphql.InterfaceType{cfg.nodeInterface},
IsTypeOf: func(v interface{}) bool {
return normalizeModelType(reflect.TypeOf(v)) == model
},
}
cfg.AdditionalTypes[t.Name] = objectType
cfg.nodeTypesByObjectType[objectType] = t
cfg.nodeObjectTypesByName[t.Name] = objectType

return objectType
}

// AddNamedType adds a named type to the schema. This is generally only required for interface
// implementations that aren't explicitly referenced elsewhere in the schema.
func (cfg *Config) AddNamedType(t graphql.NamedType) {
cfg.init()
cfg.AdditionalTypes[t.TypeName()] = t
}

// NodeInterface returns the node interface.
func (cfg *Config) NodeInterface() *graphql.InterfaceType {
cfg.init()
return cfg.nodeInterface
}

// MutationType returns the root mutation type.
func (cfg *Config) MutationType() *graphql.ObjectType {
cfg.init()
Expand Down
67 changes: 53 additions & 14 deletions examples/chat/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,67 @@ import (
"github.com/ccbrown/api-fu/examples/chat/model"
)

func DeserializeId(id string) model.Id {
const (
UserTypeId = 1
ChannelTypeId = 2
MessageTypeId = 3
)

func SerializeNodeId(typeId int, id model.Id) string {
buf := make([]byte, binary.MaxVarintLen64)
n := binary.PutVarint(buf, int64(typeId))
return base64.RawURLEncoding.EncodeToString(append(buf[:n], id...))
}

func DeserializeNodeId(id string) (int, model.Id) {
if buf, err := base64.RawURLEncoding.DecodeString(id); err == nil {
if _, n := binary.Varint(buf); n > 0 {
return model.Id(buf[n:])
if typeId, n := binary.Varint(buf); n > 0 {
return int(typeId), model.Id(buf[n:])
}
}
return nil
return 0, nil
}

var fuCfg = apifu.Config{
SerializeNodeId: func(typeId int, id interface{}) string {
buf := make([]byte, binary.MaxVarintLen64)
n := binary.PutVarint(buf, int64(typeId))
return base64.RawURLEncoding.EncodeToString(append(buf[:n], id.(model.Id)...))
},
DeserializeNodeId: func(id string) (int, interface{}) {
if buf, err := base64.RawURLEncoding.DecodeString(id); err == nil {
if typeId, n := binary.Varint(buf); n > 0 {
return int(typeId), model.Id(buf[n:])
ResolveNodesByGlobalIds: func(ctx context.Context, ids []string) ([]interface{}, error) {
var userIds []model.Id
var channelIds []model.Id
var messageIds []model.Id
for _, id := range ids {
typeId, id := DeserializeNodeId(id)
switch typeId {
case UserTypeId:
userIds = append(userIds, id)
case ChannelTypeId:
channelIds = append(channelIds, id)
case MessageTypeId:
messageIds = append(messageIds, id)
}
}
return 0, nil
sess := ctxSession(ctx)
channels, err := sess.GetChannelsByIds(channelIds...)
if err != nil {
return nil, err
}
messages, err := sess.GetMessagesByIds(messageIds...)
if err != nil {
return nil, err
}
users, err := sess.GetUsersByIds(userIds...)
if err != nil {
return nil, err
}
ret := make([]interface{}, 0, len(ids))
for _, channel := range channels {
ret = append(ret, channel)
}
for _, message := range messages {
ret = append(ret, message)
}
for _, user := range users {
ret = append(ret, user)
}
return ret, nil
},
}

Expand Down
33 changes: 23 additions & 10 deletions examples/chat/api/channel.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"context"
"reflect"
"strings"
"time"
Expand All @@ -11,14 +10,14 @@ import (
"github.com/ccbrown/api-fu/graphql"
)

var channelType = fuCfg.AddNodeType(&apifu.NodeType{
Id: 2,
Name: "Channel",
Model: reflect.TypeOf(model.Channel{}),
GetByIds: func(ctx context.Context, ids interface{}) (interface{}, error) {
return ctxSession(ctx).GetChannelsByIds(ids.([]model.Id)...)
var channelType = &graphql.ObjectType{
Name: "Channel",
ImplementedInterfaces: []*graphql.InterfaceType{fuCfg.NodeInterface()},
IsTypeOf: func(value interface{}) bool {
_, ok := value.(*model.Channel)
return ok
},
})
}

func init() {
type messageCursor struct {
Expand All @@ -27,10 +26,24 @@ func init() {
}

channelType.Fields = map[string]*graphql.FieldDefinition{
"id": apifu.OwnID("Id"),
"id": {
Type: graphql.NewNonNullType(graphql.IDType),
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
return SerializeNodeId(ChannelTypeId, ctx.Object.(*model.Channel).Id), nil
},
},
"name": apifu.NonNull(graphql.StringType, "Name"),
"creationTime": apifu.NonNull(apifu.DateTimeType, "CreationTime"),
"creator": apifu.Node(userType, "CreatorUserId"),
"creator": {
Type: userType,
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
users, err := ctxSession(ctx.Context).GetUsersByIds(ctx.Object.(*model.Channel).CreatorUserId)
if err != nil || len(users) == 0 {
return nil, err
}
return users[0], nil
},
},
"messagesConnection": apifu.TimeBasedConnection(&apifu.TimeBasedConnectionConfig{
NamePrefix: "ChannelMessages",
EdgeCursor: func(edge interface{}) apifu.TimeBasedCursor {
Expand Down
Loading

0 comments on commit cd7bcd9

Please sign in to comment.