Skip to content

Commit

Permalink
allow graphql types and fields to be gated behind feature flags (#69)
Browse files Browse the repository at this point in the history
* add graphql features and required features for field

* lots o tests

* stricter checking of object/interface fields

* support feature requirements for types

* add tests
  • Loading branch information
ccbrown authored Mar 16, 2024
1 parent c289cff commit 66fc90e
Show file tree
Hide file tree
Showing 41 changed files with 410 additions and 83 deletions.
5 changes: 4 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,13 @@ func (api *API) ServeGraphQL(w http.ResponseWriter, r *http.Request) {
}
req.Schema = api.schema
req.IdleHandler = apiRequest.IdleHandler
if api.config.Features != nil {
req.Features = api.config.Features(ctx)
}

execute := func(req *graphql.Request) *graphql.Response {
var info RequestInfo
if doc, errs := graphql.ParseAndValidate(req.Query, req.Schema, req.ValidateCost(-1, &info.Cost, api.config.DefaultFieldCost)); len(errs) > 0 {
if doc, errs := graphql.ParseAndValidate(req.Query, req.Schema, req.Features, req.ValidateCost(-1, &info.Cost, api.config.DefaultFieldCost)); len(errs) > 0 {
return &graphql.Response{
Errors: errs,
}
Expand Down
72 changes: 71 additions & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@ import (
)

func executeGraphQL(t *testing.T, api *API, query string) *http.Response {
return executeGraphQLWithFeatures(t, api, query, nil)
}

const featuresContextKey = "features"

func featuresFromContext(ctx context.Context) graphql.FeatureSet {
features, _ := ctx.Value(featuresContextKey).(graphql.FeatureSet)
return features
}

func executeGraphQLWithFeatures(t *testing.T, api *API, query string, features []string) *http.Response {
w := httptest.NewRecorder()
r, err := http.NewRequest("POST", "", strings.NewReader(query))
ctx := context.WithValue(context.Background(), featuresContextKey, graphql.NewFeatureSet(features...))
r, err := http.NewRequestWithContext(ctx, "POST", "", strings.NewReader(query))
r.Header.Set("Content-Type", "application/graphql")
require.NoError(t, err)
api.ServeGraphQL(w, r)
Expand Down Expand Up @@ -214,3 +226,61 @@ func TestMutation(t *testing.T) {
require.NoError(t, err)
assert.JSONEq(t, `{"data":{"mut":true}}`, string(body))
}

func TestFeatures(t *testing.T) {
var testCfg Config
testCfg.Features = featuresFromContext

testCfg.AddQueryField("foo", &graphql.FieldDefinition{
Type: graphql.BooleanType,
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
return true, nil
},
})

testCfg.AddQueryField("bar", &graphql.FieldDefinition{
Type: graphql.BooleanType,
RequiredFeatures: graphql.NewFeatureSet("bar"),
Resolve: func(ctx graphql.FieldContext) (interface{}, error) {
return true, nil
},
})

api, err := NewAPI(&testCfg)
require.NoError(t, err)

t.Run("NoFeatures", func(t *testing.T) {
resp := executeGraphQL(t, api, `{
foo
}`)
require.Equal(t, http.StatusOK, resp.StatusCode)

body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"data":{"foo":true}}`, string(body))
})

t.Run("NoFeatures_Error", func(t *testing.T) {
resp := executeGraphQL(t, api, `{
foo
bar
}`)
require.Equal(t, http.StatusOK, resp.StatusCode)

body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"errors":[{"locations":[{"column":4,"line":3}],"message":"Validation error: field bar does not exist on Query"}]}`, string(body))
})

t.Run("BarFeature", func(t *testing.T) {
resp := executeGraphQLWithFeatures(t, api, `{
foo
bar
}`, []string{"bar"})
require.Equal(t, http.StatusOK, resp.StatusCode)

body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"data":{"foo":true,"bar":true}}`, string(body))
})
}
2 changes: 1 addition & 1 deletion cmd/gql-client-gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func generateTypeDef(name, original string) string {

func (s *generateState) processQuery(q string) []error {
var ret []error
doc, errs := graphql.ParseAndValidate(q, s.schema)
doc, errs := graphql.ParseAndValidate(q, s.schema, nil)
if len(errs) > 0 {
for _, err := range errs {
ret = append(ret, err)
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ type Config struct {
// documentation.
PreprocessGraphQLSchemaDefinition func(schema *graphql.SchemaDefinition) error

// If given, this function will be invoked to get the feature set for a request.
Features func(ctx context.Context) graphql.FeatureSet

initOnce sync.Once
nodeInterface *graphql.InterfaceType
query *graphql.ObjectType
Expand Down
15 changes: 10 additions & 5 deletions graphql/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Request struct {
Schema *schema.Schema
OperationName string
VariableValues map[string]any
Features schema.FeatureSet
InitialValue any
IdleHandler func()
}
Expand Down Expand Up @@ -72,6 +73,7 @@ type executor struct {
Schema *schema.Schema
FragmentDefinitions map[string]*ast.FragmentDefinition
VariableValues map[string]any
Features schema.FeatureSet
Errors []*Error
Operation *ast.OperationDefinition
IdleHandler func()
Expand All @@ -89,7 +91,7 @@ func newExecutor(ctx context.Context, r *Request) (*executor, *Error) {
if err != nil {
return nil, err
}
coercedVariableValues, err := coerceVariableValues(r.Schema, operation, r.VariableValues)
coercedVariableValues, err := coerceVariableValues(r.Schema, r.Features, operation, r.VariableValues)
if err != nil {
return nil, err
}
Expand All @@ -99,6 +101,7 @@ func newExecutor(ctx context.Context, r *Request) (*executor, *Error) {
Schema: r.Schema,
FragmentDefinitions: map[string]*ast.FragmentDefinition{},
VariableValues: coercedVariableValues,
Features: r.Features,
Operation: operation,
IdleHandler: r.IdleHandler,
GroupedFieldSetCache: map[string]*GroupedFieldSet{},
Expand Down Expand Up @@ -162,7 +165,7 @@ func (e *executor) subscribe(initialValue any) (any, *Error) {
fields := item.Fields
field := fields[0]
fieldName := field.Name.Name
fieldDef := subscriptionType.Fields[fieldName]
fieldDef := subscriptionType.GetField(fieldName, e.Features)
if fieldDef == nil {
return nil, newError(field, "Undefined root subscription field.")
}
Expand All @@ -175,6 +178,7 @@ func (e *executor) subscribe(initialValue any) (any, *Error) {
Context: e.Context,
Schema: e.Schema,
Object: initialValue,
Features: e.Features,
Arguments: argumentValues,
IsSubscribe: true,
})
Expand Down Expand Up @@ -248,7 +252,7 @@ func (e *executor) executeSelections(selections []ast.Selection, objectType *sch
continue
}

fieldDef := objectType.Fields[fieldName]
fieldDef := objectType.GetField(fieldName, e.Features)
if fieldDef == nil && objectType == e.Schema.QueryType() {
fieldDef = introspection.MetaFields[fieldName]
}
Expand Down Expand Up @@ -319,6 +323,7 @@ func (e *executor) executeField(objectValue any, fields []*ast.Field, fieldDef *
Context: e.Context,
Schema: e.Schema,
Object: objectValue,
Features: e.Features,
Arguments: argumentValues,
})
if !isNil(err) {
Expand Down Expand Up @@ -606,8 +611,8 @@ func schemaType(t ast.Type, s *schema.Schema) schema.Type {
return nil
}

func coerceVariableValues(s *schema.Schema, operation *ast.OperationDefinition, variableValues map[string]any) (map[string]any, *Error) {
ret, err := validator.CoerceVariableValues(s, operation, variableValues)
func coerceVariableValues(s *schema.Schema, features schema.FeatureSet, operation *ast.OperationDefinition, variableValues map[string]any) (map[string]any, *Error) {
ret, err := validator.CoerceVariableValues(s, features, operation, variableValues)
return ret, newErrorWithValidatorError(err)
}

Expand Down
10 changes: 5 additions & 5 deletions graphql/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestSubscribe(t *testing.T) {
require.NoError(t, err)
doc, parseErrs := parser.ParseDocument([]byte(`subscription {int}`))
require.Empty(t, parseErrs)
require.Empty(t, validator.ValidateDocument(doc, s))
require.Empty(t, validator.ValidateDocument(doc, s, nil))

assert.True(t, IsSubscription(doc, ""))

Expand Down Expand Up @@ -262,7 +262,7 @@ func TestExecuteRequest(t *testing.T) {
t.Run("IntrospectionQuery", func(t *testing.T) {
parsed, parseErrs := parser.ParseDocument(introspection.Query)
require.Empty(t, parseErrs)
require.Empty(t, validator.ValidateDocument(parsed, s))
require.Empty(t, validator.ValidateDocument(parsed, s, nil))
_, errs := ExecuteRequest(context.Background(), &Request{
Document: parsed,
Schema: s,
Expand Down Expand Up @@ -419,7 +419,7 @@ func TestExecuteRequest(t *testing.T) {
t.Run(name, func(t *testing.T) {
parsed, parseErrs := parser.ParseDocument([]byte(tc.Document))
require.Empty(t, parseErrs)
require.Empty(t, validator.ValidateDocument(parsed, s))
require.Empty(t, validator.ValidateDocument(parsed, s, nil))
data, errs := ExecuteRequest(context.Background(), &Request{
Document: parsed,
Schema: s,
Expand Down Expand Up @@ -534,7 +534,7 @@ func BenchmarkExecuteRequest(b *testing.B) {
}
}`))
require.Empty(b, parseErrs)
require.Empty(b, validator.ValidateDocument(doc, s))
require.Empty(b, validator.ValidateDocument(doc, s, nil))

r := &Request{
Document: doc,
Expand Down Expand Up @@ -585,7 +585,7 @@ func TestContextCancelation(t *testing.T) {
}
}`))
require.Empty(t, parseErrs)
require.Empty(t, validator.ValidateDocument(doc, s))
require.Empty(t, validator.ValidateDocument(doc, s, nil))

r := &Request{
Document: doc,
Expand Down
18 changes: 14 additions & 4 deletions graphql/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ type Schema = schema.Schema
// SchemaDefinition defines a GraphQL schema.
type SchemaDefinition = schema.SchemaDefinition

// FeatureSet represents a set of features.
type FeatureSet = schema.FeatureSet

// NewFeatureSet creates a new feature set with the given features.
func NewFeatureSet(features ...string) FeatureSet {
return schema.NewFeatureSet(features...)
}

// NewSchema validates a schema definition and builds a Schema from it.
func NewSchema(def *SchemaDefinition) (*Schema, error) {
return schema.New(def)
Expand All @@ -152,6 +160,7 @@ type Request struct {
Schema *Schema
OperationName string
VariableValues map[string]interface{}
Features FeatureSet
Extensions map[string]interface{}
InitialValue interface{}
IdleHandler func()
Expand All @@ -171,6 +180,7 @@ func (r *Request) executorRequest(doc *ast.Document) *executor.Request {
Schema: r.Schema,
OperationName: r.OperationName,
VariableValues: r.VariableValues,
Features: r.Features,
InitialValue: r.InitialValue,
IdleHandler: r.IdleHandler,
}
Expand Down Expand Up @@ -276,7 +286,7 @@ func IsSubscription(doc *ast.Document, operationName string) bool {
}

// ParseAndValidate parses and validates a query.
func ParseAndValidate(query string, schema *Schema, additionalRules ...ValidatorRule) (*ast.Document, []*Error) {
func ParseAndValidate(query string, schema *Schema, features schema.FeatureSet, additionalRules ...ValidatorRule) (*ast.Document, []*Error) {
var errors []*Error
parsed, parseErrs := parser.ParseDocument([]byte(query))
if len(parseErrs) > 0 {
Expand All @@ -293,7 +303,7 @@ func ParseAndValidate(query string, schema *Schema, additionalRules ...Validator
}
return nil, errors
}
if validationErrs := validator.ValidateDocument(parsed, schema, additionalRules...); len(validationErrs) > 0 {
if validationErrs := validator.ValidateDocument(parsed, schema, features, additionalRules...); len(validationErrs) > 0 {
for _, err := range validationErrs {
locations := make([]Location, len(err.Locations))
for i, loc := range err.Locations {
Expand Down Expand Up @@ -334,7 +344,7 @@ func Subscribe(r *Request) (interface{}, []*Error) {
doc := r.Document
if doc == nil {
var errors []*Error
doc, errors = ParseAndValidate(r.Query, r.Schema)
doc, errors = ParseAndValidate(r.Query, r.Schema, r.Features)
if len(errors) > 0 {
return nil, errors
}
Expand All @@ -354,7 +364,7 @@ func Execute(r *Request) *Response {
doc := r.Document
if doc == nil {
var errors []*Error
doc, errors = ParseAndValidate(r.Query, r.Schema)
doc, errors = ParseAndValidate(r.Query, r.Schema, r.Features)
if len(errors) > 0 {
return &Response{
Errors: errors,
Expand Down
7 changes: 7 additions & 0 deletions graphql/schema/enum_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ type EnumType struct {
Description string
Directives []*Directive
Values map[string]*EnumValueDefinition

// This type is only available for introspection and use when the given features are enabled.
RequiredFeatures FeatureSet
}

type EnumValueDefinition struct {
Expand Down Expand Up @@ -40,6 +43,10 @@ func (t *EnumType) IsSameType(other Type) bool {
return t == other
}

func (t *EnumType) TypeRequiredFeatures() FeatureSet {
return t.RequiredFeatures
}

func (t *EnumType) TypeName() string {
return t.Name
}
Expand Down
25 changes: 25 additions & 0 deletions graphql/schema/feature_set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package schema

type FeatureSet map[string]struct{}

func NewFeatureSet(features ...string) FeatureSet {
fs := make(FeatureSet, len(features))
for _, feature := range features {
fs[feature] = struct{}{}
}
return fs
}

func (s FeatureSet) Has(feature string) bool {
_, ok := s[feature]
return ok
}

func (s FeatureSet) IsSubsetOf(other FeatureSet) bool {
for feature := range s {
if _, ok := other[feature]; !ok {
return false
}
}
return true
}
32 changes: 32 additions & 0 deletions graphql/schema/feature_set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package schema

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFeatureSet(t *testing.T) {
s := NewFeatureSet("a", "b", "c")
assert.True(t, s.Has("a"))
assert.True(t, s.Has("b"))
assert.True(t, s.Has("c"))
assert.False(t, s.Has("d"))

s2 := NewFeatureSet("a", "b")
assert.True(t, s2.IsSubsetOf(s))
assert.False(t, s.IsSubsetOf(s2))
}

func TestFeatureSet_Nil(t *testing.T) {
var s FeatureSet
assert.False(t, s.Has("a"))

s2 := NewFeatureSet("a", "b")
assert.True(t, s.IsSubsetOf(s2))
assert.False(t, s2.IsSubsetOf(s))

var s3 FeatureSet
assert.True(t, s.IsSubsetOf(s3))
assert.True(t, s3.IsSubsetOf(s))
}
Loading

0 comments on commit 66fc90e

Please sign in to comment.