Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Commit 7df24d9

Browse files
SudoBoboTotktonada
authored andcommitted
Add support of union connections
Fixes #8. Backlog / related: #84, #85, #86, #88, #89 + some refactoring / code deduplication.
1 parent 4ded4d7 commit 7df24d9

14 files changed

+906
-213
lines changed

Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ lint:
88
test/testdata/*.lua \
99
test/common/*.test.lua test/common/lua/*.lua \
1010
test/extra/*.test.lua \
11+
test/*.lua \
1112
--no-redefined --no-unused-args
1213

1314
.PHONY: test

graphql/accessor_general.lua

+104-61
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,61 @@ local function build_index_parts_tree(indexes)
559559
return roots
560560
end
561561

562+
local function set_connection_index(c, c_name, c_type, collection_name,
563+
indexes, connection_indexes)
564+
assert(type(c.index_name) == 'string',
565+
'index_name must be a string, got ' .. type(c.index_name))
566+
567+
-- validate index_name against 'indexes'
568+
local index_meta = indexes[c.destination_collection]
569+
assert(type(index_meta) == 'table',
570+
'index_meta must be a table, got ' .. type(index_meta))
571+
572+
assert(type(collection_name) == 'string', 'collection_name expected to ' ..
573+
'be string, got ' .. type(collection_name))
574+
575+
-- validate connection parts are match or being prefix of index
576+
-- fields
577+
local i = 1
578+
local index_fields = index_meta[c.index_name].fields
579+
for _, part in ipairs(c.parts) do
580+
assert(type(part.source_field) == 'string',
581+
'part.source_field must be a string, got ' ..
582+
type(part.source_field))
583+
assert(type(part.destination_field) == 'string',
584+
'part.destination_field must be a string, got ' ..
585+
type(part.destination_field))
586+
assert(part.destination_field == index_fields[i],
587+
('connection "%s" of collection "%s" has destination parts that ' ..
588+
'is not prefix of the index "%s" parts ' ..
589+
'(destination collection - "%s")'):format(c_name, collection_name,
590+
c.index_name, c.destination_collection))
591+
i = i + 1
592+
end
593+
local parts_cnt = i - 1
594+
595+
-- partial index of an unique index is not guaranteed to being
596+
-- unique
597+
assert(c_type == '1:N' or parts_cnt == #index_fields,
598+
('1:1 connection "%s" of collection "%s" ' ..
599+
'has less fields than the index of "%s" collection ' ..
600+
'(cannot prove uniqueness of the partial index)'):format(c_name,
601+
collection_name, c.index_name, c.destination_collection))
602+
603+
-- validate connection type against index uniqueness (if provided)
604+
if index_meta.unique ~= nil then
605+
assert(c_type == '1:N' or index_meta.unique == true,
606+
('1:1 connection ("%s") cannot be implemented ' ..
607+
'on top of non-unique index ("%s")'):format(
608+
c_name, c.index_name))
609+
end
610+
611+
return {
612+
index_name = c.index_name,
613+
connection_type = c_type,
614+
}
615+
end
616+
562617
--- Build `connection_indexes` table (part of `index_cache`) to use in the
563618
--- @{get_index_name} function.
564619
---
@@ -581,60 +636,28 @@ local function build_connection_indexes(indexes, collections)
581636
assert(type(collections) == 'table', 'collections must be a table, got ' ..
582637
type(collections))
583638
local connection_indexes = {}
584-
for _, collection in pairs(collections) do
639+
for collection_name, collection in pairs(collections) do
585640
for _, c in ipairs(collection.connections) do
586-
if connection_indexes[c.destination_collection] == nil then
587-
connection_indexes[c.destination_collection] = {}
588-
end
589-
local index_name = c.index_name
590-
assert(type(index_name) == 'string',
591-
'index_name must be a string, got ' .. type(index_name))
641+
if c.destination_collection ~= nil then
642+
if connection_indexes[c.destination_collection] == nil then
643+
connection_indexes[c.destination_collection] = {}
644+
end
592645

593-
-- validate index_name against 'indexes'
594-
local index_meta = indexes[c.destination_collection]
595-
assert(type(index_meta) == 'table',
596-
'index_meta must be a table, got ' .. type(index_meta))
597-
598-
-- validate connection parts are match or being prefix of index
599-
-- fields
600-
local i = 1
601-
local index_fields = index_meta[c.index_name].fields
602-
for _, part in ipairs(c.parts) do
603-
assert(type(part.source_field) == 'string',
604-
'part.source_field must be a string, got ' ..
605-
type(part.source_field))
606-
assert(type(part.destination_field) == 'string',
607-
'part.destination_field must be a string, got ' ..
608-
type(part.destination_field))
609-
assert(part.destination_field == index_fields[i],
610-
('connection "%s" of collection "%s" ' ..
611-
'has destination parts that is not prefix of the index ' ..
612-
'"%s" parts'):format(c.name, c.destination_collection,
613-
c.index_name))
614-
i = i + 1
615-
end
616-
local parts_cnt = i - 1
617-
618-
-- partial index of an unique index is not guaranteed to being
619-
-- unique
620-
assert(c.type == '1:N' or parts_cnt == #index_fields,
621-
('1:1 connection "%s" of collection "%s" ' ..
622-
'has less fields than the index "%s" has (cannot prove ' ..
623-
'uniqueness of the partial index)'):format(c.name,
624-
c.destination_collection, c.index_name))
625-
626-
-- validate connection type against index uniqueness (if provided)
627-
if index_meta.unique ~= nil then
628-
assert(c.type == '1:N' or index_meta.unique == true,
629-
('1:1 connection ("%s") cannot be implemented ' ..
630-
'on top of non-unique index ("%s")'):format(
631-
c.name, index_name))
646+
connection_indexes[c.destination_collection][c.name] =
647+
set_connection_index(c, c.name, c.type, collection_name,
648+
indexes, connection_indexes)
632649
end
633650

634-
connection_indexes[c.destination_collection][c.name] = {
635-
index_name = index_name,
636-
connection_type = c.type,
637-
}
651+
if c.variants ~= nil then
652+
for _, v in ipairs(c.variants) do
653+
if connection_indexes[v.destination_collection] == nil then
654+
connection_indexes[v.destination_collection] = {}
655+
end
656+
connection_indexes[v.destination_collection][c.name] =
657+
set_connection_index(v, c.name, c.type, collection_name,
658+
indexes, connection_indexes)
659+
end
660+
end
638661
end
639662
end
640663
return connection_indexes
@@ -678,7 +701,7 @@ local function validate_collections(collections, schemas)
678701
type(schema_name))
679702
assert(schemas[schema_name] ~= nil,
680703
('cannot find schema "%s" for collection "%s"'):format(
681-
schema_name, collection_name))
704+
schema_name, collection_name))
682705
local connections = collection.connections
683706
assert(connections == nil or type(connections) == 'table',
684707
'collection.connections must be nil or table, got ' ..
@@ -688,16 +711,36 @@ local function validate_collections(collections, schemas)
688711
'connection must be a table, got ' .. type(connection))
689712
assert(type(connection.name) == 'string',
690713
'connection.name must be a string, got ' ..
691-
type(connection.name))
692-
assert(type(connection.destination_collection) == 'string',
693-
'connection.destination_collection must be a string, got ' ..
694-
type(connection.destination_collection))
695-
assert(type(connection.parts) == 'table',
696-
'connection.parts must be a string, got ' ..
697-
type(connection.parts))
698-
assert(type(connection.index_name) == 'string',
699-
'connection.index_name must be a string, got ' ..
700-
type(connection.index_name))
714+
type(connection.name))
715+
if connection.destination_collection then
716+
assert(type(connection.destination_collection) == 'string',
717+
'connection.destination_collection must be a string, got ' ..
718+
type(connection.destination_collection))
719+
assert(type(connection.parts) == 'table',
720+
'connection.parts must be a string, got ' ..
721+
type(connection.parts))
722+
assert(type(connection.index_name) == 'string',
723+
'connection.index_name must be a string, got ' ..
724+
type(connection.index_name))
725+
elseif connection.variants then
726+
for _, v in pairs(connection.variants) do
727+
assert(type(v.determinant) == 'table', "variant's " ..
728+
"determinant must be a table, got " ..
729+
type(v.determinant))
730+
assert(type(v.destination_collection) == 'string',
731+
'variant.destination_collection must be a string, ' ..
732+
'got ' .. type(v.destination_collection))
733+
assert(type(v.parts) == 'table',
734+
'variant.parts must be a table, got ' .. type(v.parts))
735+
assert(type(v.index_name) == 'string',
736+
'variant.index_name must be a string, got ' ..
737+
type(v.index_name))
738+
end
739+
else
740+
assert(false, ('connection "%s" of collection "%s" does not ' ..
741+
'have neither destination collection nor variants field'):
742+
format(connection.name, collection_name))
743+
end
701744
end
702745
end
703746
end

graphql/core/execute.lua

+15-4
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,13 @@ end
7070

7171
local evaluateSelections
7272

73-
local function completeValue(fieldType, result, subSelections, context)
73+
-- @param[opt] resolvedType a type to be used instead of one returned by
74+
-- `fieldType.resolveType(result)` in case when the `fieldType` is Interface or
75+
-- Union; that is needed to increase flexibility of an union type resolving
76+
-- (e.g. resolving by a parent object instead of a current object) via
77+
-- returning it from the `fieldType.resolve` function, which called before
78+
-- `resolvedType` and may need to determine the type itself for its needs
79+
local function completeValue(fieldType, result, subSelections, context, resolvedType)
7480
local fieldTypeName = fieldType.__type
7581

7682
if fieldTypeName == 'NonNull' then
@@ -111,7 +117,11 @@ local function completeValue(fieldType, result, subSelections, context)
111117
local fields = evaluateSelections(fieldType, result, subSelections, context)
112118
return next(fields) and fields or context.schema.__emptyObject
113119
elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then
114-
local objectType = fieldType.resolveType(result)
120+
local objectType = resolvedType or fieldType.resolveType(result)
121+
while objectType.__type == 'NonNull' do
122+
objectType = objectType.ofType
123+
end
124+
115125
return evaluateSelections(objectType, result, subSelections, context)
116126
end
117127

@@ -151,10 +161,11 @@ local function getFieldEntry(objectType, object, fields, context)
151161
qcontext = context.qcontext
152162
}
153163

154-
local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info)
164+
-- resolvedType is optional return value
165+
local resolvedObject, resolvedType = (fieldType.resolve or defaultResolver)(object, arguments, info)
155166
local subSelections = query_util.mergeSelectionSets(fields)
156167

157-
return completeValue(fieldType.kind, resolvedObject, subSelections, context)
168+
return completeValue(fieldType.kind, resolvedObject, subSelections, context, resolvedType)
158169
end
159170

160171
evaluateSelections = function(objectType, object, selections, context)

graphql/core/query_util.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ function query_util.collectFields(objectType, selections, visitedFragments, resu
7474
end
7575
elseif selection.kind == 'inlineFragment' then
7676
if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then
77-
collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
77+
query_util.collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
7878
end
7979
elseif selection.kind == 'fragmentSpread' then
8080
local fragmentName = selection.name.value
8181
if shouldIncludeNode(selection, context) and not visitedFragments[fragmentName] then
8282
visitedFragments[fragmentName] = true
8383
local fragment = context.fragmentMap[fragmentName]
8484
if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then
85-
collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
85+
query_util.collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
8686
end
8787
end
8888
end

graphql/core/rules.lua

+8
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,14 @@ function rules.fragmentSpreadIsPossible(node, context)
323323
local fragmentTypes = getTypes(fragmentType)
324324

325325
local valid = util.find(parentTypes, function(kind)
326+
local kind = kind
327+
-- Here is the check that type, mentioned in '... on some_type'
328+
-- conditional fragment expression is type of some field of parent object.
329+
-- In case of Union parent object and NonNull wrapped inner types
330+
-- graphql-lua missed unwrapping so we add it here
331+
while kind.__type == 'NonNull' do
332+
kind = kind.ofType
333+
end
326334
return fragmentTypes[kind]
327335
end)
328336

graphql/core/types.lua

+5-1
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,15 @@ end
155155
function types.union(config)
156156
assert(type(config.name) == 'string', 'type name must be provided as a string')
157157
assert(type(config.types) == 'table', 'types table must be provided')
158+
if config.resolveType then
159+
assert(type(config.resolveType) == 'function', 'must provide resolveType as a function')
160+
end
158161

159162
local instance = {
160163
__type = 'Union',
161164
name = config.name,
162-
types = config.types
165+
types = config.types,
166+
resolveType = config.resolveType
163167
}
164168

165169
instance.nonNull = types.nonNull(instance)

0 commit comments

Comments
 (0)