Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Commit 9f22717

Browse files
committed
Validate a nested variable by an argument type
1 parent 58f8962 commit 9f22717

File tree

8 files changed

+680
-424
lines changed

8 files changed

+680
-424
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ default:
1010
lint:
1111
luacheck graphql/*.lua \
1212
graphql/core/execute.lua \
13+
graphql/core/rules.lua \
1314
graphql/core/validate_variables.lua \
1415
graphql/convert_schema/*.lua \
1516
graphql/server/*.lua \

graphql/convert_schema/union.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ local function create_union_types(avro_schema, opts)
9898
if type == 'null' then
9999
is_nullable = true
100100
else
101-
local variant_type = convert(type, {context = context})
102101
local box_field_name = type.name or avro_helpers.avro_type(type)
102+
table.insert(context.path, box_field_name)
103+
local variant_type = convert(type, {context = context})
104+
table.remove(context.path, #context.path)
103105
union_types[#union_types + 1] = box_type(variant_type,
104106
box_field_name, {
105107
gen_argument = gen_argument,

graphql/core/rules.lua

Lines changed: 175 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
local yaml = require('yaml')
12
local path = (...):gsub('%.[^%.]+$', '')
23
local types = require(path .. '.types')
34
local util = require(path .. '.util')
4-
local schema = require(path .. '.schema')
55
local introspection = require(path .. '.introspection')
66
local query_util = require(path .. '.query_util')
7+
local graphql_utils = require('graphql.utils')
78

89
local function getParentField(context, name, count)
910
if introspection.fieldMap[name] then return introspection.fieldMap[name] end
@@ -162,7 +163,8 @@ function rules.unambiguousSelections(node, context)
162163

163164
validateField(key, fieldEntry)
164165
elseif selection.kind == 'inlineFragment' then
165-
local parentType = selection.typeCondition and context.schema:getType(selection.typeCondition.name.value) or parentType
166+
local parentType = selection.typeCondition and context.schema:getType(
167+
selection.typeCondition.name.value) or parentType
166168
validateSelectionSet(selection.selectionSet, parentType)
167169
elseif selection.kind == 'fragmentSpread' then
168170
local fragmentDefinition = context.fragmentMap[selection.name.value]
@@ -436,117 +438,192 @@ function rules.variablesAreDefined(node, context)
436438
end
437439
end
438440

439-
function rules.variableUsageAllowed(node, context)
440-
if context.currentOperation then
441-
local variableMap = {}
442-
for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do
443-
variableMap[definition.variable.name.value] = definition
444-
end
445-
446-
local arguments
447-
448-
if node.kind == 'field' then
449-
arguments = { [node.name.value] = node.arguments }
450-
elseif node.kind == 'fragmentSpread' then
451-
local seen = {}
452-
local function collectArguments(referencedNode)
453-
if referencedNode.kind == 'selectionSet' then
454-
for _, selection in ipairs(referencedNode.selections) do
455-
if not seen[selection] then
456-
seen[selection] = true
457-
collectArguments(selection)
458-
end
459-
end
460-
elseif referencedNode.kind == 'field' and referencedNode.arguments then
461-
local fieldName = referencedNode.name.value
462-
arguments[fieldName] = arguments[fieldName] or {}
463-
for _, argument in ipairs(referencedNode.arguments) do
464-
table.insert(arguments[fieldName], argument)
465-
end
466-
elseif referencedNode.kind == 'inlineFragment' then
467-
return collectArguments(referencedNode.selectionSet)
468-
elseif referencedNode.kind == 'fragmentSpread' then
469-
local fragment = context.fragmentMap[referencedNode.name.value]
470-
return fragment and collectArguments(fragment.selectionSet)
471-
end
441+
-- {{{ variableUsageAllowed
442+
443+
local function collectArguments(referencedNode, context, seen, arguments)
444+
if referencedNode.kind == 'selectionSet' then
445+
for _, selection in ipairs(referencedNode.selections) do
446+
if not seen[selection] then
447+
seen[selection] = true
448+
collectArguments(selection, context, seen, arguments)
472449
end
450+
end
451+
elseif referencedNode.kind == 'field' and referencedNode.arguments then
452+
local fieldName = referencedNode.name.value
453+
arguments[fieldName] = arguments[fieldName] or {}
454+
for _, argument in ipairs(referencedNode.arguments) do
455+
table.insert(arguments[fieldName], argument)
456+
end
457+
elseif referencedNode.kind == 'inlineFragment' then
458+
return collectArguments(referencedNode.selectionSet, context, seen,
459+
arguments)
460+
elseif referencedNode.kind == 'fragmentSpread' then
461+
local fragment = context.fragmentMap[referencedNode.name.value]
462+
return fragment and collectArguments(fragment.selectionSet, context, seen,
463+
arguments)
464+
end
465+
end
466+
467+
-- http://facebook.github.io/graphql/June2018/#AreTypesCompatible()
468+
local function isTypeSubTypeOf(subType, superType, context)
469+
if subType == superType then return true end
470+
471+
if superType.__type == 'NonNull' then
472+
if subType.__type == 'NonNull' then
473+
return isTypeSubTypeOf(subType.ofType, superType.ofType, context)
474+
end
475+
476+
return false
477+
elseif subType.__type == 'NonNull' then
478+
return isTypeSubTypeOf(subType.ofType, superType, context)
479+
end
480+
481+
if superType.__type == 'List' then
482+
if subType.__type == 'List' then
483+
return isTypeSubTypeOf(subType.ofType, superType.ofType, context)
484+
end
485+
486+
return false
487+
elseif subType.__type == 'List' then
488+
return false
489+
end
473490

474-
local fragment = context.fragmentMap[node.name.value]
475-
if fragment then
476-
arguments = {}
477-
collectArguments(fragment.selectionSet)
491+
if superType.__type == 'Scalar' and superType.subtype == 'InputUnion' then
492+
local types = superType.types
493+
for i = 1, #types do
494+
if types[i] == subType then
495+
return true
478496
end
479497
end
480498

481-
if not arguments then return end
499+
return false
500+
end
501+
502+
return false
503+
end
504+
505+
local function getTypeName(t)
506+
if t.name ~= nil then
507+
if t.name == 'Scalar' and t.subtype == 'InputMap' then
508+
return ('InputMap(%s)'):format(getTypeName(t.values))
509+
elseif t.name == 'Scalar' and t.subtype == 'InputUnion' then
510+
local typeNames = {}
511+
for _, child in ipairs(t.types) do
512+
table.insert(typeNames, getTypeName(child))
513+
end
514+
return ('InputUnion(%s)'):format(table.concat(typeNames, ','))
515+
end
516+
return t.name
517+
elseif t.__type == 'NonNull' then
518+
return ('NonNull(%s)'):format(getTypeName(t.ofType))
519+
elseif t.__type == 'List' then
520+
return ('List(%s)'):format(getTypeName(t.ofType))
521+
end
482522

483-
for field in pairs(arguments) do
484-
local parentField = getParentField(context, field)
485-
for i = 1, #arguments[field] do
486-
local argument = arguments[field][i]
487-
if argument.value.kind == 'variable' then
488-
local argumentType = parentField.arguments[argument.name.value]
523+
local orig_encode_use_tostring = yaml.cfg.encode_use_tostring
524+
local err = ('Internal error: unknown type:\n%s'):format(yaml.encode(t))
525+
yaml.cfg({encode_use_tostring = orig_encode_use_tostring})
526+
error(err)
527+
end
489528

490-
local variableName = argument.value.name.value
491-
local variableDefinition = variableMap[variableName]
492-
local hasDefault = variableDefinition.defaultValue ~= nil
529+
local function isVariableTypesValid(argument, argumentType, context,
530+
variableMap)
531+
if argument.value.kind == 'variable' then
532+
-- found a variable, check types compatibility
533+
local variableName = argument.value.name.value
534+
local variableDefinition = variableMap[variableName]
535+
local hasDefault = variableDefinition.defaultValue ~= nil
493536

494-
local variableType = query_util.typeFromAST(variableDefinition.type,
495-
context.schema)
537+
local variableType = query_util.typeFromAST(variableDefinition.type,
538+
context.schema)
496539

497-
if hasDefault and variableType.__type ~= 'NonNull' then
498-
variableType = types.nonNull(variableType)
499-
end
540+
if hasDefault and variableType.__type ~= 'NonNull' then
541+
variableType = types.nonNull(variableType)
542+
end
500543

501-
local function isTypeSubTypeOf(subType, superType)
502-
if subType == superType then return true end
503-
504-
if superType.__type == 'NonNull' then
505-
if subType.__type == 'NonNull' then
506-
return isTypeSubTypeOf(subType.ofType, superType.ofType)
507-
end
508-
509-
return false
510-
elseif subType.__type == 'NonNull' then
511-
return isTypeSubTypeOf(subType.ofType, superType)
512-
end
513-
514-
if superType.__type == 'List' then
515-
if subType.__type == 'List' then
516-
return isTypeSubTypeOf(subType.ofType, superType.ofType)
517-
end
518-
519-
return false
520-
elseif subType.__type == 'List' then
521-
return false
522-
end
523-
524-
if subType.__type ~= 'Object' then return false end
525-
526-
if superType.__type == 'Interface' then
527-
local implementors = context.schema:getImplementors(superType.name)
528-
return implementors and implementors[context.schema:getType(subType.name)]
529-
elseif superType.__type == 'Union' then
530-
local types = superType.types
531-
for i = 1, #types do
532-
if types[i] == subType then
533-
return true
534-
end
535-
end
536-
537-
return false
538-
end
539-
540-
return false
541-
end
544+
if not isTypeSubTypeOf(variableType, argumentType, context) then
545+
return false, ('Variable "%s" type mismatch: the variable type "%s" ' ..
546+
'is not compatible with the argument type "%s"'):format(variableName,
547+
getTypeName(variableType), getTypeName(argumentType))
548+
end
549+
elseif argument.value.kind == 'inputObject' then
550+
-- find variables deeper
551+
for _, child in ipairs(argument.value.values) do
552+
local isInputObject = argumentType.__type == 'InputObject'
553+
local isInputMap = argumentType.__type == 'Scalar' and
554+
argumentType.subtype == 'InputMap'
555+
local isInputUnion = argumentType.__type == 'Scalar' and
556+
argumentType.subtype == 'InputUnion'
557+
558+
if isInputObject then
559+
local childArgumentType = argumentType.fields[child.name].kind
560+
local ok, err = isVariableTypesValid(child, childArgumentType, context,
561+
variableMap)
562+
if not ok then return false, err end
563+
elseif isInputMap then
564+
local childArgumentType = argumentType.values
565+
local ok, err = isVariableTypesValid(child, childArgumentType, context,
566+
variableMap)
567+
if not ok then return false, err end
568+
elseif isInputUnion then
569+
local has_ok = false
570+
local first_err
571+
572+
for _, childArgumentType in ipairs(argumentType.types) do
573+
local ok, err = isVariableTypesValid(child,
574+
childArgumentType, context, variableMap)
575+
has_ok = has_ok or ok
576+
first_err = first_err or graphql_utils.strip_error(err)
577+
if ok then break end
578+
end
542579

543-
if not isTypeSubTypeOf(variableType, argumentType) then
544-
error('Variable type mismatch')
545-
end
580+
if not has_ok then
581+
return false, first_err
546582
end
547583
end
548584
end
549585
end
586+
return true
587+
end
588+
589+
function rules.variableUsageAllowed(node, context)
590+
if not context.currentOperation then return end
591+
592+
local variableMap = {}
593+
local variableDefinitions = context.currentOperation.variableDefinitions
594+
for _, definition in ipairs(variableDefinitions or {}) do
595+
variableMap[definition.variable.name.value] = definition
596+
end
597+
598+
local arguments
599+
600+
if node.kind == 'field' then
601+
arguments = { [node.name.value] = node.arguments }
602+
elseif node.kind == 'fragmentSpread' then
603+
local seen = {}
604+
local fragment = context.fragmentMap[node.name.value]
605+
if fragment then
606+
arguments = {}
607+
collectArguments(fragment.selectionSet, context, seen, arguments)
608+
end
609+
end
610+
611+
if not arguments then return end
612+
613+
for field in pairs(arguments) do
614+
local parentField = getParentField(context, field)
615+
for i = 1, #arguments[field] do
616+
local argument = arguments[field][i]
617+
local argumentType = parentField.arguments[argument.name.value]
618+
local ok, err = isVariableTypesValid(argument, argumentType, context,
619+
variableMap)
620+
if not ok then
621+
error(err)
622+
end
623+
end
624+
end
550625
end
551626

627+
-- }}}
628+
552629
return rules

graphql/core/schema.lua

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ function schema:generateTypeMap(node)
5050
node.fields = type(node.fields) == 'function' and node.fields() or node.fields
5151
self.typeMap[node.name] = node
5252

53-
if node.__type == 'Union' then
53+
if node.__type == 'Union' or (node.__type == 'Scalar' and
54+
node.subtype == 'InputUnion') then
5455
for _, type in ipairs(node.types) do
5556
self:generateTypeMap(type)
5657
end
@@ -77,6 +78,10 @@ function schema:generateTypeMap(node)
7778
self:generateTypeMap(field.kind)
7879
end
7980
end
81+
82+
if node.type == 'Scalar' and node.subtype == 'InputMap' then
83+
self:generateTypeMap(node.values)
84+
end
8085
end
8186

8287
function schema:generateDirectiveMap()

graphql/core/types.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ function types.inputUnion(config)
249249
__type = 'Scalar',
250250
subtype = 'InputUnion',
251251
name = config.name,
252+
types = config.types,
252253
serialize = function(value) return value end,
253254
parseValue = function(value) return value end,
254255
parseLiteral = function(node)

0 commit comments

Comments
 (0)