|
| 1 | +local yaml = require('yaml') |
1 | 2 | local path = (...):gsub('%.[^%.]+$', '')
|
2 | 3 | local types = require(path .. '.types')
|
3 | 4 | local util = require(path .. '.util')
|
4 |
| -local schema = require(path .. '.schema') |
5 | 5 | local introspection = require(path .. '.introspection')
|
6 | 6 | local query_util = require(path .. '.query_util')
|
| 7 | +local graphql_utils = require('graphql.utils') |
7 | 8 |
|
8 | 9 | local function getParentField(context, name, count)
|
9 | 10 | if introspection.fieldMap[name] then return introspection.fieldMap[name] end
|
@@ -162,7 +163,8 @@ function rules.unambiguousSelections(node, context)
|
162 | 163 |
|
163 | 164 | validateField(key, fieldEntry)
|
164 | 165 | elseif selection.kind == 'inlineFragment' then
|
165 |
| - local parentType = selection.typeCondition and context.schema:getType(selection.typeCondition.name.value) or parentType |
| 166 | + local parentType = selection.typeCondition and context.schema:getType( |
| 167 | + selection.typeCondition.name.value) or parentType |
166 | 168 | validateSelectionSet(selection.selectionSet, parentType)
|
167 | 169 | elseif selection.kind == 'fragmentSpread' then
|
168 | 170 | local fragmentDefinition = context.fragmentMap[selection.name.value]
|
@@ -436,117 +438,192 @@ function rules.variablesAreDefined(node, context)
|
436 | 438 | end
|
437 | 439 | end
|
438 | 440 |
|
439 |
| -function rules.variableUsageAllowed(node, context) |
440 |
| - if context.currentOperation then |
441 |
| - local variableMap = {} |
442 |
| - for _, definition in ipairs(context.currentOperation.variableDefinitions or {}) do |
443 |
| - variableMap[definition.variable.name.value] = definition |
444 |
| - end |
445 |
| - |
446 |
| - local arguments |
447 |
| - |
448 |
| - if node.kind == 'field' then |
449 |
| - arguments = { [node.name.value] = node.arguments } |
450 |
| - elseif node.kind == 'fragmentSpread' then |
451 |
| - local seen = {} |
452 |
| - local function collectArguments(referencedNode) |
453 |
| - if referencedNode.kind == 'selectionSet' then |
454 |
| - for _, selection in ipairs(referencedNode.selections) do |
455 |
| - if not seen[selection] then |
456 |
| - seen[selection] = true |
457 |
| - collectArguments(selection) |
458 |
| - end |
459 |
| - end |
460 |
| - elseif referencedNode.kind == 'field' and referencedNode.arguments then |
461 |
| - local fieldName = referencedNode.name.value |
462 |
| - arguments[fieldName] = arguments[fieldName] or {} |
463 |
| - for _, argument in ipairs(referencedNode.arguments) do |
464 |
| - table.insert(arguments[fieldName], argument) |
465 |
| - end |
466 |
| - elseif referencedNode.kind == 'inlineFragment' then |
467 |
| - return collectArguments(referencedNode.selectionSet) |
468 |
| - elseif referencedNode.kind == 'fragmentSpread' then |
469 |
| - local fragment = context.fragmentMap[referencedNode.name.value] |
470 |
| - return fragment and collectArguments(fragment.selectionSet) |
471 |
| - end |
| 441 | +-- {{{ variableUsageAllowed |
| 442 | + |
| 443 | +local function collectArguments(referencedNode, context, seen, arguments) |
| 444 | + if referencedNode.kind == 'selectionSet' then |
| 445 | + for _, selection in ipairs(referencedNode.selections) do |
| 446 | + if not seen[selection] then |
| 447 | + seen[selection] = true |
| 448 | + collectArguments(selection, context, seen, arguments) |
472 | 449 | end
|
| 450 | + end |
| 451 | + elseif referencedNode.kind == 'field' and referencedNode.arguments then |
| 452 | + local fieldName = referencedNode.name.value |
| 453 | + arguments[fieldName] = arguments[fieldName] or {} |
| 454 | + for _, argument in ipairs(referencedNode.arguments) do |
| 455 | + table.insert(arguments[fieldName], argument) |
| 456 | + end |
| 457 | + elseif referencedNode.kind == 'inlineFragment' then |
| 458 | + return collectArguments(referencedNode.selectionSet, context, seen, |
| 459 | + arguments) |
| 460 | + elseif referencedNode.kind == 'fragmentSpread' then |
| 461 | + local fragment = context.fragmentMap[referencedNode.name.value] |
| 462 | + return fragment and collectArguments(fragment.selectionSet, context, seen, |
| 463 | + arguments) |
| 464 | + end |
| 465 | +end |
| 466 | + |
| 467 | +-- http://facebook.github.io/graphql/June2018/#AreTypesCompatible() |
| 468 | +local function isTypeSubTypeOf(subType, superType, context) |
| 469 | + if subType == superType then return true end |
| 470 | + |
| 471 | + if superType.__type == 'NonNull' then |
| 472 | + if subType.__type == 'NonNull' then |
| 473 | + return isTypeSubTypeOf(subType.ofType, superType.ofType, context) |
| 474 | + end |
| 475 | + |
| 476 | + return false |
| 477 | + elseif subType.__type == 'NonNull' then |
| 478 | + return isTypeSubTypeOf(subType.ofType, superType, context) |
| 479 | + end |
| 480 | + |
| 481 | + if superType.__type == 'List' then |
| 482 | + if subType.__type == 'List' then |
| 483 | + return isTypeSubTypeOf(subType.ofType, superType.ofType, context) |
| 484 | + end |
| 485 | + |
| 486 | + return false |
| 487 | + elseif subType.__type == 'List' then |
| 488 | + return false |
| 489 | + end |
473 | 490 |
|
474 |
| - local fragment = context.fragmentMap[node.name.value] |
475 |
| - if fragment then |
476 |
| - arguments = {} |
477 |
| - collectArguments(fragment.selectionSet) |
| 491 | + if superType.__type == 'Scalar' and superType.subtype == 'InputUnion' then |
| 492 | + local types = superType.types |
| 493 | + for i = 1, #types do |
| 494 | + if types[i] == subType then |
| 495 | + return true |
478 | 496 | end
|
479 | 497 | end
|
480 | 498 |
|
481 |
| - if not arguments then return end |
| 499 | + return false |
| 500 | + end |
| 501 | + |
| 502 | + return false |
| 503 | +end |
| 504 | + |
| 505 | +local function getTypeName(t) |
| 506 | + if t.name ~= nil then |
| 507 | + if t.name == 'Scalar' and t.subtype == 'InputMap' then |
| 508 | + return ('InputMap(%s)'):format(getTypeName(t.values)) |
| 509 | + elseif t.name == 'Scalar' and t.subtype == 'InputUnion' then |
| 510 | + local typeNames = {} |
| 511 | + for _, child in ipairs(t.types) do |
| 512 | + table.insert(typeNames, getTypeName(child)) |
| 513 | + end |
| 514 | + return ('InputUnion(%s)'):format(table.concat(typeNames, ',')) |
| 515 | + end |
| 516 | + return t.name |
| 517 | + elseif t.__type == 'NonNull' then |
| 518 | + return ('NonNull(%s)'):format(getTypeName(t.ofType)) |
| 519 | + elseif t.__type == 'List' then |
| 520 | + return ('List(%s)'):format(getTypeName(t.ofType)) |
| 521 | + end |
482 | 522 |
|
483 |
| - for field in pairs(arguments) do |
484 |
| - local parentField = getParentField(context, field) |
485 |
| - for i = 1, #arguments[field] do |
486 |
| - local argument = arguments[field][i] |
487 |
| - if argument.value.kind == 'variable' then |
488 |
| - local argumentType = parentField.arguments[argument.name.value] |
| 523 | + local orig_encode_use_tostring = yaml.cfg.encode_use_tostring |
| 524 | + local err = ('Internal error: unknown type:\n%s'):format(yaml.encode(t)) |
| 525 | + yaml.cfg({encode_use_tostring = orig_encode_use_tostring}) |
| 526 | + error(err) |
| 527 | +end |
489 | 528 |
|
490 |
| - local variableName = argument.value.name.value |
491 |
| - local variableDefinition = variableMap[variableName] |
492 |
| - local hasDefault = variableDefinition.defaultValue ~= nil |
| 529 | +local function isVariableTypesValid(argument, argumentType, context, |
| 530 | + variableMap) |
| 531 | + if argument.value.kind == 'variable' then |
| 532 | + -- found a variable, check types compatibility |
| 533 | + local variableName = argument.value.name.value |
| 534 | + local variableDefinition = variableMap[variableName] |
| 535 | + local hasDefault = variableDefinition.defaultValue ~= nil |
493 | 536 |
|
494 |
| - local variableType = query_util.typeFromAST(variableDefinition.type, |
495 |
| - context.schema) |
| 537 | + local variableType = query_util.typeFromAST(variableDefinition.type, |
| 538 | + context.schema) |
496 | 539 |
|
497 |
| - if hasDefault and variableType.__type ~= 'NonNull' then |
498 |
| - variableType = types.nonNull(variableType) |
499 |
| - end |
| 540 | + if hasDefault and variableType.__type ~= 'NonNull' then |
| 541 | + variableType = types.nonNull(variableType) |
| 542 | + end |
500 | 543 |
|
501 |
| - local function isTypeSubTypeOf(subType, superType) |
502 |
| - if subType == superType then return true end |
503 |
| - |
504 |
| - if superType.__type == 'NonNull' then |
505 |
| - if subType.__type == 'NonNull' then |
506 |
| - return isTypeSubTypeOf(subType.ofType, superType.ofType) |
507 |
| - end |
508 |
| - |
509 |
| - return false |
510 |
| - elseif subType.__type == 'NonNull' then |
511 |
| - return isTypeSubTypeOf(subType.ofType, superType) |
512 |
| - end |
513 |
| - |
514 |
| - if superType.__type == 'List' then |
515 |
| - if subType.__type == 'List' then |
516 |
| - return isTypeSubTypeOf(subType.ofType, superType.ofType) |
517 |
| - end |
518 |
| - |
519 |
| - return false |
520 |
| - elseif subType.__type == 'List' then |
521 |
| - return false |
522 |
| - end |
523 |
| - |
524 |
| - if subType.__type ~= 'Object' then return false end |
525 |
| - |
526 |
| - if superType.__type == 'Interface' then |
527 |
| - local implementors = context.schema:getImplementors(superType.name) |
528 |
| - return implementors and implementors[context.schema:getType(subType.name)] |
529 |
| - elseif superType.__type == 'Union' then |
530 |
| - local types = superType.types |
531 |
| - for i = 1, #types do |
532 |
| - if types[i] == subType then |
533 |
| - return true |
534 |
| - end |
535 |
| - end |
536 |
| - |
537 |
| - return false |
538 |
| - end |
539 |
| - |
540 |
| - return false |
541 |
| - end |
| 544 | + if not isTypeSubTypeOf(variableType, argumentType, context) then |
| 545 | + return false, ('Variable "%s" type mismatch: the variable type "%s" ' .. |
| 546 | + 'is not compatible with the argument type "%s"'):format(variableName, |
| 547 | + getTypeName(variableType), getTypeName(argumentType)) |
| 548 | + end |
| 549 | + elseif argument.value.kind == 'inputObject' then |
| 550 | + -- find variables deeper |
| 551 | + for _, child in ipairs(argument.value.values) do |
| 552 | + local isInputObject = argumentType.__type == 'InputObject' |
| 553 | + local isInputMap = argumentType.__type == 'Scalar' and |
| 554 | + argumentType.subtype == 'InputMap' |
| 555 | + local isInputUnion = argumentType.__type == 'Scalar' and |
| 556 | + argumentType.subtype == 'InputUnion' |
| 557 | + |
| 558 | + if isInputObject then |
| 559 | + local childArgumentType = argumentType.fields[child.name].kind |
| 560 | + local ok, err = isVariableTypesValid(child, childArgumentType, context, |
| 561 | + variableMap) |
| 562 | + if not ok then return false, err end |
| 563 | + elseif isInputMap then |
| 564 | + local childArgumentType = argumentType.values |
| 565 | + local ok, err = isVariableTypesValid(child, childArgumentType, context, |
| 566 | + variableMap) |
| 567 | + if not ok then return false, err end |
| 568 | + elseif isInputUnion then |
| 569 | + local has_ok = false |
| 570 | + local first_err |
| 571 | + |
| 572 | + for _, childArgumentType in ipairs(argumentType.types) do |
| 573 | + local ok, err = isVariableTypesValid(child, |
| 574 | + childArgumentType, context, variableMap) |
| 575 | + has_ok = has_ok or ok |
| 576 | + first_err = first_err or graphql_utils.strip_error(err) |
| 577 | + if ok then break end |
| 578 | + end |
542 | 579 |
|
543 |
| - if not isTypeSubTypeOf(variableType, argumentType) then |
544 |
| - error('Variable type mismatch') |
545 |
| - end |
| 580 | + if not has_ok then |
| 581 | + return false, first_err |
546 | 582 | end
|
547 | 583 | end
|
548 | 584 | end
|
549 | 585 | end
|
| 586 | + return true |
| 587 | +end |
| 588 | + |
| 589 | +function rules.variableUsageAllowed(node, context) |
| 590 | + if not context.currentOperation then return end |
| 591 | + |
| 592 | + local variableMap = {} |
| 593 | + local variableDefinitions = context.currentOperation.variableDefinitions |
| 594 | + for _, definition in ipairs(variableDefinitions or {}) do |
| 595 | + variableMap[definition.variable.name.value] = definition |
| 596 | + end |
| 597 | + |
| 598 | + local arguments |
| 599 | + |
| 600 | + if node.kind == 'field' then |
| 601 | + arguments = { [node.name.value] = node.arguments } |
| 602 | + elseif node.kind == 'fragmentSpread' then |
| 603 | + local seen = {} |
| 604 | + local fragment = context.fragmentMap[node.name.value] |
| 605 | + if fragment then |
| 606 | + arguments = {} |
| 607 | + collectArguments(fragment.selectionSet, context, seen, arguments) |
| 608 | + end |
| 609 | + end |
| 610 | + |
| 611 | + if not arguments then return end |
| 612 | + |
| 613 | + for field in pairs(arguments) do |
| 614 | + local parentField = getParentField(context, field) |
| 615 | + for i = 1, #arguments[field] do |
| 616 | + local argument = arguments[field][i] |
| 617 | + local argumentType = parentField.arguments[argument.name.value] |
| 618 | + local ok, err = isVariableTypesValid(argument, argumentType, context, |
| 619 | + variableMap) |
| 620 | + if not ok then |
| 621 | + error(err) |
| 622 | + end |
| 623 | + end |
| 624 | + end |
550 | 625 | end
|
551 | 626 |
|
| 627 | +-- }}} |
| 628 | + |
552 | 629 | return rules
|
0 commit comments