Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Commit 2bd33d7

Browse files
committed
add support of union connections, closes #8
1 parent 1408934 commit 2bd33d7

9 files changed

+858
-202
lines changed

graphql/accessor_general.lua

+103-60
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,62 @@ local function build_index_parts_tree(indexes)
559559
return roots
560560
end
561561

562+
local function set_connection_index(c, c_name, c_type, collection_name,
563+
indexes, connection_indexes)
564+
assert(type(c.index_name) == 'string',
565+
'index_name must be a string, got ' .. type(c.index_name))
566+
567+
-- validate index_name against 'indexes'
568+
local index_meta = indexes[c.destination_collection]
569+
assert(type(index_meta) == 'table',
570+
'index_meta must be a table, got ' .. type(index_meta))
571+
572+
assert(type(collection_name) == 'string', 'collection_name expected to ' ..
573+
'be string, got ' .. type(collection_name))
574+
575+
-- validate connection parts are match or being prefix of index
576+
-- fields
577+
local i = 1
578+
local index_fields = index_meta[c.index_name].fields
579+
for _, part in ipairs(c.parts) do
580+
assert(type(part.source_field) == 'string',
581+
'part.source_field must be a string, got ' ..
582+
type(part.source_field))
583+
assert(type(part.destination_field) == 'string',
584+
'part.destination_field must be a string, got ' ..
585+
type(part.destination_field))
586+
assert(part.destination_field == index_fields[i],
587+
('connection "%s" of collection "%s" has destination parts that ' ..
588+
'is not prefix of the index "%s" parts ' ..
589+
'(destination collection - "%s")'):format(c_name, collection_name,
590+
c.index_name, c.destination_collection))
591+
i = i + 1
592+
end
593+
local parts_cnt = i - 1
594+
595+
-- partial index of an unique index is not guaranteed to being
596+
-- unique
597+
assert(c_type == '1:N' or parts_cnt == #index_fields,
598+
('1:1 connection "%s" of collection "%s" ' ..
599+
'has less fields than the index of "%s" collection ' ..
600+
'(cannot prove uniqueness of the partial index)'):format(c_name,
601+
collection_name, c.index_name, c.destination_collection))
602+
603+
-- validate connection type against index uniqueness (if provided)
604+
if index_meta.unique ~= nil then
605+
assert(c_type == '1:N' or index_meta.unique == true,
606+
('1:1 connection ("%s") cannot be implemented ' ..
607+
'on top of non-unique index ("%s")'):format(
608+
c_name, c.index_name))
609+
end
610+
611+
return {
612+
index_name = c.index_name,
613+
connection_type = c_type,
614+
}
615+
end
616+
617+
562618
--- Build `connection_indexes` table (part of `index_cache`) to use in the
563619
--- @{get_index_name} function.
564620
---
@@ -581,60 +637,28 @@ local function build_connection_indexes(indexes, collections)
581637
assert(type(collections) == 'table', 'collections must be a table, got ' ..
582638
type(collections))
583639
local connection_indexes = {}
584-
for _, collection in pairs(collections) do
640+
for collection_name, collection in pairs(collections) do
585641
for _, c in ipairs(collection.connections) do
586-
if connection_indexes[c.destination_collection] == nil then
587-
connection_indexes[c.destination_collection] = {}
588-
end
589-
local index_name = c.index_name
590-
assert(type(index_name) == 'string',
591-
'index_name must be a string, got ' .. type(index_name))
642+
if c.destination_collection ~= nil then
643+
if connection_indexes[c.destination_collection] == nil then
644+
connection_indexes[c.destination_collection] = {}
645+
end
592646

593-
-- validate index_name against 'indexes'
594-
local index_meta = indexes[c.destination_collection]
595-
assert(type(index_meta) == 'table',
596-
'index_meta must be a table, got ' .. type(index_meta))
597-
598-
-- validate connection parts are match or being prefix of index
599-
-- fields
600-
local i = 1
601-
local index_fields = index_meta[c.index_name].fields
602-
for _, part in ipairs(c.parts) do
603-
assert(type(part.source_field) == 'string',
604-
'part.source_field must be a string, got ' ..
605-
type(part.source_field))
606-
assert(type(part.destination_field) == 'string',
607-
'part.destination_field must be a string, got ' ..
608-
type(part.destination_field))
609-
assert(part.destination_field == index_fields[i],
610-
('connection "%s" of collection "%s" ' ..
611-
'has destination parts that is not prefix of the index ' ..
612-
'"%s" parts'):format(c.name, c.destination_collection,
613-
c.index_name))
614-
i = i + 1
615-
end
616-
local parts_cnt = i - 1
617-
618-
-- partial index of an unique index is not guaranteed to being
619-
-- unique
620-
assert(c.type == '1:N' or parts_cnt == #index_fields,
621-
('1:1 connection "%s" of collection "%s" ' ..
622-
'has less fields than the index "%s" has (cannot prove ' ..
623-
'uniqueness of the partial index)'):format(c.name,
624-
c.destination_collection, c.index_name))
625-
626-
-- validate connection type against index uniqueness (if provided)
627-
if index_meta.unique ~= nil then
628-
assert(c.type == '1:N' or index_meta.unique == true,
629-
('1:1 connection ("%s") cannot be implemented ' ..
630-
'on top of non-unique index ("%s")'):format(
631-
c.name, index_name))
647+
connection_indexes[c.destination_collection][c.name] =
648+
set_connection_index(c, c.name, c.type, collection_name,
649+
indexes, connection_indexes)
632650
end
633651

634-
connection_indexes[c.destination_collection][c.name] = {
635-
index_name = index_name,
636-
connection_type = c.type,
637-
}
652+
if c.variants ~= nil then
653+
for _, v in ipairs(c.variants) do
654+
if connection_indexes[v.destination_collection] == nil then
655+
connection_indexes[v.destination_collection] = {}
656+
end
657+
connection_indexes[v.destination_collection][c.name] =
658+
set_connection_index(v, c.name, c.type, collection_name,
659+
indexes, connection_indexes)
660+
end
661+
end
638662
end
639663
end
640664
return connection_indexes
@@ -675,29 +699,48 @@ local function validate_collections(collections, schemas)
675699
local schema_name = collection.schema_name
676700
assert(type(schema_name) == 'string',
677701
'collection.schema_name must be a string, got ' ..
678-
type(schema_name))
702+
type(schema_name))
679703
assert(schemas[schema_name] ~= nil,
680704
('cannot find schema "%s" for collection "%s"'):format(
681705
schema_name, collection_name))
682706
local connections = collection.connections
683707
assert(connections == nil or type(connections) == 'table',
684-
'collection.connections must be nil or table, got ' ..
685-
type(connections))
708+
'collection.connections must be nil or table, got ' ..
709+
type(connections))
686710
for _, connection in ipairs(connections) do
687711
assert(type(connection) == 'table',
688712
'connection must be a table, got ' .. type(connection))
689713
assert(type(connection.name) == 'string',
690714
'connection.name must be a string, got ' ..
691-
type(connection.name))
692-
assert(type(connection.destination_collection) == 'string',
715+
type(connection.name))
716+
if connection.destination_collection then
717+
assert(type(connection.destination_collection) == 'string',
693718
'connection.destination_collection must be a string, got ' ..
694-
type(connection.destination_collection))
695-
assert(type(connection.parts) == 'table',
719+
type(connection.destination_collection))
720+
assert(type(connection.parts) == 'table',
696721
'connection.parts must be a string, got ' ..
697-
type(connection.parts))
698-
assert(type(connection.index_name) == 'string',
722+
type(connection.parts))
723+
assert(type(connection.index_name) == 'string',
699724
'connection.index_name must be a string, got ' ..
700-
type(connection.index_name))
725+
type(connection.index_name))
726+
elseif connection.variants then
727+
for _, v in pairs(connection.variants) do
728+
assert(type(v.determinant) == 'table', 'variant\'s ' ..
729+
'determinant must be a table, got ' ..
730+
type(v.determinant))
731+
assert(type(v.destination_collection) == 'string',
732+
'variant.destination_collection must be a string, ' ..
733+
'got ' .. type(v.destination_collection))
734+
assert(type(v.parts) == 'table',
735+
'variant.parts must be a table, got ' .. type(v.parts))
736+
assert(type(v.index_name) == 'string',
737+
'variant.index_name must be a string, got ' ..
738+
type(v.index_name))
739+
end
740+
else
741+
assert(false, ('collection doesn\'t have neither destination' ..
742+
'collection nor variants field'))
743+
end
701744
end
702745
end
703746
end

graphql/core/query_util.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ function query_util.collectFields(objectType, selections, visitedFragments, resu
7474
end
7575
elseif selection.kind == 'inlineFragment' then
7676
if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then
77-
collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
77+
query_util.collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context)
7878
end
7979
elseif selection.kind == 'fragmentSpread' then
8080
local fragmentName = selection.name.value
8181
if shouldIncludeNode(selection, context) and not visitedFragments[fragmentName] then
8282
visitedFragments[fragmentName] = true
8383
local fragment = context.fragmentMap[fragmentName]
8484
if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then
85-
collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
85+
query_util.collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context)
8686
end
8787
end
8888
end

graphql/core/types.lua

+3-1
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,13 @@ end
155155
function types.union(config)
156156
assert(type(config.name) == 'string', 'type name must be provided as a string')
157157
assert(type(config.types) == 'table', 'types table must be provided')
158+
assert(type(config.resolveType) == 'function', 'must provide resolveType as a function')
158159

159160
local instance = {
160161
__type = 'Union',
161162
name = config.name,
162-
types = config.types
163+
types = config.types,
164+
resolveType = config.resolveType
163165
}
164166

165167
instance.nonNull = types.nonNull(instance)

0 commit comments

Comments
 (0)