@@ -4,6 +4,7 @@ local util = require(path .. '.util')
4
4
local schema = require (path .. ' .schema' )
5
5
local introspection = require (path .. ' .introspection' )
6
6
local query_util = require (path .. ' .query_util' )
7
+ local graphql_utils = require (' graphql.utils' )
7
8
8
9
local function getParentField (context , name , count )
9
10
if introspection .fieldMap [name ] then return introspection .fieldMap [name ] end
@@ -484,66 +485,100 @@ function rules.variableUsageAllowed(node, context)
484
485
local parentField = getParentField (context , field )
485
486
for i = 1 , # arguments [field ] do
486
487
local argument = arguments [field ][i ]
487
- if argument .value .kind == ' variable' then
488
- local argumentType = parentField .arguments [argument .name .value ]
488
+ local argumentType = parentField .arguments [argument .name .value ]
489
+ local function recursiveValidateVariableType (argument , argumentType )
490
+ if argument .value .kind == ' variable' then
491
+ local variableName = argument .value .name .value
492
+ local variableDefinition = variableMap [variableName ]
493
+ local hasDefault = variableDefinition .defaultValue ~= nil
494
+
495
+ local variableType = query_util .typeFromAST (variableDefinition .type ,
496
+ context .schema )
497
+
498
+ if hasDefault and variableType .__type ~= ' NonNull' then
499
+ variableType = types .nonNull (variableType )
500
+ end
489
501
490
- local variableName = argument .value .name .value
491
- local variableDefinition = variableMap [variableName ]
492
- local hasDefault = variableDefinition .defaultValue ~= nil
502
+ local function isTypeSubTypeOf (subType , superType )
503
+ if subType == superType then return true end
493
504
494
- local variableType = query_util .typeFromAST (variableDefinition .type ,
495
- context .schema )
505
+ if superType .__type == ' NonNull' then
506
+ if subType .__type == ' NonNull' then
507
+ return isTypeSubTypeOf (subType .ofType , superType .ofType )
508
+ end
496
509
497
- if hasDefault and variableType .__type ~= ' NonNull' then
498
- variableType = types .nonNull (variableType )
499
- end
510
+ return false
511
+ elseif subType .__type == ' NonNull' then
512
+ return isTypeSubTypeOf (subType .ofType , superType )
513
+ end
514
+
515
+ if superType .__type == ' List' then
516
+ if subType .__type == ' List' then
517
+ return isTypeSubTypeOf (subType .ofType , superType .ofType )
518
+ end
500
519
501
- local function isTypeSubTypeOf (subType , superType )
502
- if subType == superType then return true end
520
+ return false
521
+ elseif subType .__type == ' List' then
522
+ return false
523
+ end
503
524
504
- if superType .__type == ' NonNull' then
505
- if subType .__type == ' NonNull' then
506
- return isTypeSubTypeOf (subType .ofType , superType .ofType )
525
+ -- XXX: InputMap, ...; all named types must be allowed
526
+ if subType .__type ~= ' Object' and
527
+ subType .__type ~= ' InputObject' then
528
+ return false
507
529
end
508
530
509
- return false
510
- elseif subType .__type == ' NonNull' then
511
- return isTypeSubTypeOf (subType .ofType , superType )
512
- end
531
+ if superType .__type == ' Interface' then
532
+ local implementors = context .schema :getImplementors (superType .name )
533
+ return implementors and implementors [context .schema :getType (subType .name )]
534
+ elseif superType .__type == ' Union' or
535
+ (superType .__type == ' Scalar' and
536
+ superType .subtype == ' InputUnion' ) then
537
+ -- false then
538
+ local types = superType .types
539
+ for i = 1 , # types do
540
+ if types [i ] == subType then
541
+ return true
542
+ end
543
+ end
513
544
514
- if superType .__type == ' List' then
515
- if subType .__type == ' List' then
516
- return isTypeSubTypeOf (subType .ofType , superType .ofType )
545
+ return false
517
546
end
518
547
519
- return false
520
- elseif subType .__type == ' List' then
521
548
return false
522
549
end
523
550
524
- if subType .__type ~= ' Object' then return false end
525
-
526
- if superType .__type == ' Interface' then
527
- local implementors = context .schema :getImplementors (superType .name )
528
- return implementors and implementors [context .schema :getType (subType .name )]
529
- elseif superType .__type == ' Union' then
530
- local types = superType .types
531
- for i = 1 , # types do
532
- if types [i ] == subType then
533
- return true
551
+ if not isTypeSubTypeOf (variableType , argumentType ) then
552
+ error (' Variable type mismatch' )
553
+ end
554
+ elseif argument .value .kind == ' inputObject' then
555
+ for _ , child in ipairs (argument .value .values ) do
556
+ if argumentType .__type == ' InputObject' then
557
+ local childArgumentType = argumentType .fields [child .name ].kind
558
+ recursiveValidateVariableType (child , childArgumentType )
559
+ elseif argumentType .__type == ' Scalar' and
560
+ argumentType .subtype == ' InputMap' then
561
+ local childArgumentType = argumentType .values
562
+ recursiveValidateVariableType (child , childArgumentType )
563
+ elseif argumentType .__type == ' Scalar' and
564
+ argumentType .subtype == ' InputUnion' then
565
+ local has_ok
566
+ local first_err
567
+ for _ , childArgumentType in ipairs (argumentType .types ) do
568
+ local ok , err = pcall (recursiveValidateVariableType , child ,
569
+ childArgumentType )
570
+ has_ok = has_ok or ok
571
+ first_err = first_err or graphql_utils .strip_error (err )
572
+ if ok then break end
573
+ end
574
+ if not has_ok then
575
+ error (first_err )
534
576
end
535
577
end
536
-
537
- return false
538
578
end
539
-
540
- return false
541
- end
542
-
543
- if not isTypeSubTypeOf (variableType , argumentType ) then
544
- error (' Variable type mismatch' )
545
579
end
546
580
end
581
+ recursiveValidateVariableType (argument , argumentType )
547
582
end
548
583
end
549
584
end
0 commit comments