diff --git a/Makefile b/Makefile index 11c701e..e916419 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ lint: test/testdata/*.lua \ test/common/*.test.lua test/common/lua/*.lua \ test/extra/*.test.lua \ + test/*.lua \ --no-redefined --no-unused-args .PHONY: test diff --git a/graphql/accessor_general.lua b/graphql/accessor_general.lua index 4b19f34..625a25e 100644 --- a/graphql/accessor_general.lua +++ b/graphql/accessor_general.lua @@ -559,6 +559,61 @@ local function build_index_parts_tree(indexes) return roots end +local function set_connection_index(c, c_name, c_type, collection_name, + indexes, connection_indexes) + assert(type(c.index_name) == 'string', + 'index_name must be a string, got ' .. type(c.index_name)) + + -- validate index_name against 'indexes' + local index_meta = indexes[c.destination_collection] + assert(type(index_meta) == 'table', + 'index_meta must be a table, got ' .. type(index_meta)) + + assert(type(collection_name) == 'string', 'collection_name expected to ' .. + 'be string, got ' .. type(collection_name)) + + -- validate connection parts are match or being prefix of index + -- fields + local i = 1 + local index_fields = index_meta[c.index_name].fields + for _, part in ipairs(c.parts) do + assert(type(part.source_field) == 'string', + 'part.source_field must be a string, got ' .. + type(part.source_field)) + assert(type(part.destination_field) == 'string', + 'part.destination_field must be a string, got ' .. + type(part.destination_field)) + assert(part.destination_field == index_fields[i], + ('connection "%s" of collection "%s" has destination parts that ' .. + 'is not prefix of the index "%s" parts ' .. + '(destination collection - "%s")'):format(c_name, collection_name, + c.index_name, c.destination_collection)) + i = i + 1 + end + local parts_cnt = i - 1 + + -- partial index of an unique index is not guaranteed to being + -- unique + assert(c_type == '1:N' or parts_cnt == #index_fields, + ('1:1 connection "%s" of collection "%s" ' .. + 'has less fields than the index of "%s" collection ' .. + '(cannot prove uniqueness of the partial index)'):format(c_name, + collection_name, c.index_name, c.destination_collection)) + + -- validate connection type against index uniqueness (if provided) + if index_meta.unique ~= nil then + assert(c_type == '1:N' or index_meta.unique == true, + ('1:1 connection ("%s") cannot be implemented ' .. + 'on top of non-unique index ("%s")'):format( + c_name, c.index_name)) + end + + return { + index_name = c.index_name, + connection_type = c_type, + } +end + --- Build `connection_indexes` table (part of `index_cache`) to use in the --- @{get_index_name} function. --- @@ -581,60 +636,28 @@ local function build_connection_indexes(indexes, collections) assert(type(collections) == 'table', 'collections must be a table, got ' .. type(collections)) local connection_indexes = {} - for _, collection in pairs(collections) do + for collection_name, collection in pairs(collections) do for _, c in ipairs(collection.connections) do - if connection_indexes[c.destination_collection] == nil then - connection_indexes[c.destination_collection] = {} - end - local index_name = c.index_name - assert(type(index_name) == 'string', - 'index_name must be a string, got ' .. type(index_name)) + if c.destination_collection ~= nil then + if connection_indexes[c.destination_collection] == nil then + connection_indexes[c.destination_collection] = {} + end - -- validate index_name against 'indexes' - local index_meta = indexes[c.destination_collection] - assert(type(index_meta) == 'table', - 'index_meta must be a table, got ' .. type(index_meta)) - - -- validate connection parts are match or being prefix of index - -- fields - local i = 1 - local index_fields = index_meta[c.index_name].fields - for _, part in ipairs(c.parts) do - assert(type(part.source_field) == 'string', - 'part.source_field must be a string, got ' .. - type(part.source_field)) - assert(type(part.destination_field) == 'string', - 'part.destination_field must be a string, got ' .. - type(part.destination_field)) - assert(part.destination_field == index_fields[i], - ('connection "%s" of collection "%s" ' .. - 'has destination parts that is not prefix of the index ' .. - '"%s" parts'):format(c.name, c.destination_collection, - c.index_name)) - i = i + 1 - end - local parts_cnt = i - 1 - - -- partial index of an unique index is not guaranteed to being - -- unique - assert(c.type == '1:N' or parts_cnt == #index_fields, - ('1:1 connection "%s" of collection "%s" ' .. - 'has less fields than the index "%s" has (cannot prove ' .. - 'uniqueness of the partial index)'):format(c.name, - c.destination_collection, c.index_name)) - - -- validate connection type against index uniqueness (if provided) - if index_meta.unique ~= nil then - assert(c.type == '1:N' or index_meta.unique == true, - ('1:1 connection ("%s") cannot be implemented ' .. - 'on top of non-unique index ("%s")'):format( - c.name, index_name)) + connection_indexes[c.destination_collection][c.name] = + set_connection_index(c, c.name, c.type, collection_name, + indexes, connection_indexes) end - connection_indexes[c.destination_collection][c.name] = { - index_name = index_name, - connection_type = c.type, - } + if c.variants ~= nil then + for _, v in ipairs(c.variants) do + if connection_indexes[v.destination_collection] == nil then + connection_indexes[v.destination_collection] = {} + end + connection_indexes[v.destination_collection][c.name] = + set_connection_index(v, c.name, c.type, collection_name, + indexes, connection_indexes) + end + end end end return connection_indexes @@ -678,7 +701,7 @@ local function validate_collections(collections, schemas) type(schema_name)) assert(schemas[schema_name] ~= nil, ('cannot find schema "%s" for collection "%s"'):format( - schema_name, collection_name)) + schema_name, collection_name)) local connections = collection.connections assert(connections == nil or type(connections) == 'table', 'collection.connections must be nil or table, got ' .. @@ -688,16 +711,36 @@ local function validate_collections(collections, schemas) 'connection must be a table, got ' .. type(connection)) assert(type(connection.name) == 'string', 'connection.name must be a string, got ' .. - type(connection.name)) - assert(type(connection.destination_collection) == 'string', - 'connection.destination_collection must be a string, got ' .. - type(connection.destination_collection)) - assert(type(connection.parts) == 'table', - 'connection.parts must be a string, got ' .. - type(connection.parts)) - assert(type(connection.index_name) == 'string', - 'connection.index_name must be a string, got ' .. - type(connection.index_name)) + type(connection.name)) + if connection.destination_collection then + assert(type(connection.destination_collection) == 'string', + 'connection.destination_collection must be a string, got ' .. + type(connection.destination_collection)) + assert(type(connection.parts) == 'table', + 'connection.parts must be a string, got ' .. + type(connection.parts)) + assert(type(connection.index_name) == 'string', + 'connection.index_name must be a string, got ' .. + type(connection.index_name)) + elseif connection.variants then + for _, v in pairs(connection.variants) do + assert(type(v.determinant) == 'table', "variant's " .. + "determinant must be a table, got " .. + type(v.determinant)) + assert(type(v.destination_collection) == 'string', + 'variant.destination_collection must be a string, ' .. + 'got ' .. type(v.destination_collection)) + assert(type(v.parts) == 'table', + 'variant.parts must be a table, got ' .. type(v.parts)) + assert(type(v.index_name) == 'string', + 'variant.index_name must be a string, got ' .. + type(v.index_name)) + end + else + assert(false, ('connection "%s" of collection "%s" does not ' .. + 'have neither destination collection nor variants field'): + format(connection.name, collection_name)) + end end end end diff --git a/graphql/core/execute.lua b/graphql/core/execute.lua index 53807e2..db27814 100644 --- a/graphql/core/execute.lua +++ b/graphql/core/execute.lua @@ -70,7 +70,13 @@ end local evaluateSelections -local function completeValue(fieldType, result, subSelections, context) +-- @param[opt] resolvedType a type to be used instead of one returned by +-- `fieldType.resolveType(result)` in case when the `fieldType` is Interface or +-- Union; that is needed to increase flexibility of an union type resolving +-- (e.g. resolving by a parent object instead of a current object) via +-- returning it from the `fieldType.resolve` function, which called before +-- `resolvedType` and may need to determine the type itself for its needs +local function completeValue(fieldType, result, subSelections, context, resolvedType) local fieldTypeName = fieldType.__type if fieldTypeName == 'NonNull' then @@ -111,7 +117,11 @@ local function completeValue(fieldType, result, subSelections, context) local fields = evaluateSelections(fieldType, result, subSelections, context) return next(fields) and fields or context.schema.__emptyObject elseif fieldTypeName == 'Interface' or fieldTypeName == 'Union' then - local objectType = fieldType.resolveType(result) + local objectType = resolvedType or fieldType.resolveType(result) + while objectType.__type == 'NonNull' do + objectType = objectType.ofType + end + return evaluateSelections(objectType, result, subSelections, context) end @@ -151,10 +161,11 @@ local function getFieldEntry(objectType, object, fields, context) qcontext = context.qcontext } - local resolvedObject = (fieldType.resolve or defaultResolver)(object, arguments, info) + -- resolvedType is optional return value + local resolvedObject, resolvedType = (fieldType.resolve or defaultResolver)(object, arguments, info) local subSelections = query_util.mergeSelectionSets(fields) - return completeValue(fieldType.kind, resolvedObject, subSelections, context) + return completeValue(fieldType.kind, resolvedObject, subSelections, context, resolvedType) end evaluateSelections = function(objectType, object, selections, context) diff --git a/graphql/core/query_util.lua b/graphql/core/query_util.lua index 7887878..9f2f7c3 100644 --- a/graphql/core/query_util.lua +++ b/graphql/core/query_util.lua @@ -74,7 +74,7 @@ function query_util.collectFields(objectType, selections, visitedFragments, resu end elseif selection.kind == 'inlineFragment' then if shouldIncludeNode(selection, context) and doesFragmentApply(selection, objectType, context) then - collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context) + query_util.collectFields(objectType, selection.selectionSet.selections, visitedFragments, result, context) end elseif selection.kind == 'fragmentSpread' then local fragmentName = selection.name.value @@ -82,7 +82,7 @@ function query_util.collectFields(objectType, selections, visitedFragments, resu visitedFragments[fragmentName] = true local fragment = context.fragmentMap[fragmentName] if fragment and shouldIncludeNode(fragment, context) and doesFragmentApply(fragment, objectType, context) then - collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context) + query_util.collectFields(objectType, fragment.selectionSet.selections, visitedFragments, result, context) end end end diff --git a/graphql/core/rules.lua b/graphql/core/rules.lua index 61005ea..41ab022 100644 --- a/graphql/core/rules.lua +++ b/graphql/core/rules.lua @@ -323,6 +323,14 @@ function rules.fragmentSpreadIsPossible(node, context) local fragmentTypes = getTypes(fragmentType) local valid = util.find(parentTypes, function(kind) + local kind = kind + -- Here is the check that type, mentioned in '... on some_type' + -- conditional fragment expression is type of some field of parent object. + -- In case of Union parent object and NonNull wrapped inner types + -- graphql-lua missed unwrapping so we add it here + while kind.__type == 'NonNull' do + kind = kind.ofType + end return fragmentTypes[kind] end) diff --git a/graphql/core/types.lua b/graphql/core/types.lua index e24a30d..236ff69 100644 --- a/graphql/core/types.lua +++ b/graphql/core/types.lua @@ -155,11 +155,15 @@ end function types.union(config) assert(type(config.name) == 'string', 'type name must be provided as a string') assert(type(config.types) == 'table', 'types table must be provided') + if config.resolveType then + assert(type(config.resolveType) == 'function', 'must provide resolveType as a function') + end local instance = { __type = 'Union', name = config.name, - types = config.types + types = config.types, + resolveType = config.resolveType } instance.nonNull = types.nonNull(instance) diff --git a/graphql/tarantool_graphql.lua b/graphql/tarantool_graphql.lua index 3026f52..3bcb865 100644 --- a/graphql/tarantool_graphql.lua +++ b/graphql/tarantool_graphql.lua @@ -6,8 +6,32 @@ --- * GraphQL top level statement must be a collection name. Arguments for this --- statement match non-deducible field names of corresponding object and --- passed to an accessor function in the filter argument. +--- +--- Border cases: +--- +--- * Unions: as GraphQL specification says "...no fields may be queried on +--- Union type without the use of typed fragments." Tarantool_graphql +--- behaves this way. So 'common fields' are not supported. This does NOT +--- work: +--- +--- ``` +--- hero { +--- hero_id -- common field; does NOT work +--- ... on human { +--- name +--- } +--- ... on droid { +--- model +--- } +--- } +--- ``` +--- +--- (GraphQL spec: http://facebook.github.io/graphql/October2016/#sec-Unions) +--- Also, no arguments are currently allowed for fragments. +--- See issue about this (https://github.com/facebook/graphql/issues/204) local json = require('json') +local yaml = require('yaml') local parse = require('graphql.core.parse') local schema = require('graphql.core.schema') @@ -250,6 +274,367 @@ local function convert_record_fields(state, fields) return res end +local function args_from_destination_collection(state, collection, + connection_type) + if connection_type == '1:1' then + return state.object_arguments[collection] + elseif connection_type == '1:1*' then + return state.object_arguments[collection] + elseif connection_type == '1:N' then + return state.all_arguments[collection] + else + error('unknown connection type: ' .. tostring(connection_type)) + end +end + +local function specify_destination_type(destination_type, connection_type) + if connection_type == '1:1' then + return types.nonNull(destination_type) + elseif connection_type == '1:1*' then + return destination_type + elseif connection_type == '1:N' then + return types.nonNull(types.list(types.nonNull(destination_type))) + else + error('unknown connection type: ' .. tostring(connection_type)) + end +end + +local function parent_args_values(parent, connection_parts) + local destination_args_names = {} + local destination_args_values = {} + for _, part in ipairs(connection_parts) do + assert(type(part.source_field) == 'string', + 'part.source_field must be a string, got ' .. + type(part.destination_field)) + assert(type(part.destination_field) == 'string', + 'part.destination_field must be a string, got ' .. + type(part.destination_field)) + + destination_args_names[#destination_args_names + 1] = + part.destination_field + local value = parent[part.source_field] + destination_args_values[#destination_args_values + 1] = value + end + + return destination_args_names, destination_args_values +end + +-- Check FULL match constraint before request of +-- destination object(s). Note that connection key parts +-- can be prefix of index key parts. Zero parts count +-- considered as ok by this check. +local function are_all_parts_null(parent, connection_parts) + local are_all_parts_null = true + local are_all_parts_non_null = true + for _, part in ipairs(connection_parts) do + local value = parent[part.source_field] + + if value ~= nil then -- nil or box.NULL + are_all_parts_null = false + else + are_all_parts_non_null = false + end + end + + local ok = are_all_parts_null or are_all_parts_non_null + if not ok then -- avoid extra json.encode() + assert(ok, + 'FULL MATCH constraint was failed: connection ' .. + 'key parts must be all non-nulls or all nulls; ' .. + 'object: ' .. json.encode(parent)) + end + + return are_all_parts_null +end + +local function separate_args_instance(args_instance, connection_args, + connection_list_args) + local object_args_instance = {} + local list_args_instance = {} + for k, v in pairs(args_instance) do + if connection_list_args[k] ~= nil then + list_args_instance[k] = v + elseif connection_args[k] ~= nil then + object_args_instance[k] = v + else + error(('cannot found "%s" field ("%s" value) ' .. + 'within allowed fields'):format(tostring(k), + tostring(v))) + end + end + return object_args_instance, list_args_instance +end + +--- The function converts passed simple connection to a field of GraphQL type. +--- +--- @tparam table state for read state.accessor and previously filled +--- state.nullable_collection_types (those are gql types) +--- @tparam table avro_schema input avro-schema +--- @tparam[opt] table collection table with schema_name, connections fields +--- described a collection (e.g. tarantool's spaces) +--- +--- @tparam table state for for collection types +--- @tparam table connection simple connection to create field on +--- @tparam table collection_name name of the collection which has given +--- connection +local function convert_simple_connection(state, connection, collection_name) + local c = connection + assert(type(c.destination_collection) == 'string', + 'connection.destination_collection must be a string, got ' .. + type(c.destination_collection)) + assert(type(c.parts) == 'table', + 'connection.parts must be a table, got ' .. type(c.parts)) + + -- gql type of connection field + local destination_type = + state.nullable_collection_types[c.destination_collection] + + assert(destination_type ~= nil, + ('destination_type (named %s) must not be nil'):format( + c.destination_collection)) + + + local c_args = args_from_destination_collection(state, + c.destination_collection, c.type) + destination_type = specify_destination_type(destination_type, c.type) + + local c_list_args = state.list_arguments[c.destination_collection] + + local field = { + name = c.name, + kind = destination_type, + arguments = c_args, + resolve = function(parent, args_instance, info) + local destination_args_names, destination_args_values = + parent_args_values(parent, c.parts) + + -- Avoid non-needed index lookup on a destination + -- collection when all connection parts are null: + -- * return null for 1:1* connection; + -- * return {} for 1:N connection (except the case when + -- source collection is the Query pseudo-collection). + if collection_name ~= 'Query' and are_all_parts_null(parent, c.parts) + then + if c.type ~= '1:1*' and c.type ~= '1:N' then + -- `if` is to avoid extra json.encode + assert(c.type == '1:1*' or c.type == '1:N', + ('only 1:1* or 1:N connections can have ' .. + 'all key parts null; parent is %s from ' .. + 'collection "%s"'):format(json.encode(parent), + tostring(collection_name))) + end + return c.type == '1:N' and {} or nil + end + + local from = { + collection_name = collection_name, + connection_name = c.name, + destination_args_names = destination_args_names, + destination_args_values = destination_args_values, + } + local extra = { + qcontext = info.qcontext + } + + -- object_args_instance will be passed to 'filter' + -- list_args_instance will be passed to 'args' + local object_args_instance, list_args_instance = + separate_args_instance(args_instance, c_args, c_list_args) + + local objs = state.accessor:select(parent, + c.destination_collection, from, + object_args_instance, list_args_instance, extra) + assert(type(objs) == 'table', + 'objs list received from an accessor ' .. + 'must be a table, got ' .. type(objs)) + if c.type == '1:1' or c.type == '1:1*' then + -- we expect here exactly one object even for 1:1* + -- connections because we processed all-parts-are-null + -- situation above + assert(#objs == 1, 'expect one matching object, got ' .. + tostring(#objs)) + return objs[1] + else -- c.type == '1:N' + return objs + end + end, + } + + return field +end + +--- The function converts passed union connection to a field of GraphQL type. +--- It combines destination collections of passed union connection into +--- the Union GraphQL type. +--- (destination collections are 'types' of a 'Union' in GraphQL). +--- +--- @tparam table state for collection types +--- @tparam table connection union connection to create field on +--- @tparam table collection_name name of the collection which has given +--- connection +local function convert_union_connection(state, connection, collection_name) + local c = connection + local union_types = {} + local collection_to_arguments = {} + local collection_to_list_arguments = {} + + for _, v in ipairs(c.variants) do + assert(v.determinant, 'each variant should have a determinant') + assert(type(v.determinant) == 'table', 'variant\'s determinant ' .. + 'must end be a table, got ' .. type(v.determinant)) + assert(type(v.destination_collection) == 'string', + 'variant.destination_collection must be a string, got ' .. + type(v.destination_collection)) + assert(type(v.parts) == 'table', + 'variant.parts must be a table, got ' .. type(v.parts)) + + local destination_type = + state.nullable_collection_types[v.destination_collection] + assert(destination_type ~= nil, + ('destination_type (named %s) must not be nil'):format( + v.destination_collection)) + + local v_args = args_from_destination_collection(state, + v.destination_collection, c.type) + destination_type = specify_destination_type(destination_type, c.type) + + local v_list_args = state.list_arguments[v.destination_collection] + + union_types[#union_types + 1] = destination_type + + collection_to_arguments[v.destination_collection] = v_args + collection_to_list_arguments[v.destination_collection] = v_list_args + end + + local determinant_keys = utils.get_keys(c.variants[1].determinant) + + local resolve_variant = function (parent) + assert(utils.do_have_keys(parent, determinant_keys), + ('Parent object of union object doesn\'t have determinant ' .. + 'fields which are necessary to determine which resolving ' .. + 'variant should be used. Union parent object:\n"%s"\n' .. + 'Determinant keys:\n"%s"'): + format(yaml.encode(parent), yaml.encode(determinant_keys))) + + local variant_num + local resulting_variant + for i, variant in ipairs(c.variants) do + variant_num = i + local is_match = utils.is_subtable(parent, variant.determinant) + + if is_match then + resulting_variant = variant + break + end + end + + assert(resulting_variant, ('Variant resolving failed.'.. + 'Parent object: "%s"\n'):format(yaml.encode(parent))) + return resulting_variant, variant_num + end + + local field = { + name = c.name, + kind = types.union({ + name = c.name, + types = union_types, + }), + arguments = nil, -- see Border cases/Unions at the top of the file + resolve = function(parent, args_instance, info) + local v, variant_num = resolve_variant(parent) + local destination_type = union_types[variant_num] + local destination_collection = + state.nullable_collection_types[v.destination_collection] + local destination_args_names, destination_args_values = + parent_args_values(parent, v.parts) + + -- Avoid non-needed index lookup on a destination + -- collection when all connection parts are null: + -- * return null for 1:1* connection; + -- * return {} for 1:N connection (except the case when + -- source collection is the Query pseudo-collection). + if collection_name ~= 'Query' and are_all_parts_null(parent, v.parts) + then + if c.type ~= '1:1*' and c.type ~= '1:N' then + -- `if` is to avoid extra json.encode + assert(c.type == '1:1*' or c.type == '1:N', + ('only 1:1* or 1:N connections can have ' .. + 'all key parts null; parent is %s from ' .. + 'collection "%s"'):format(json.encode(parent), + tostring(collection_name))) + end + return c.type == '1:N' and {} or nil, destination_type + end + + local from = { + collection_name = collection_name, + connection_name = c.name, + destination_args_names = destination_args_names, + destination_args_values = destination_args_values, + } + local extra = { + qcontext = info.qcontext + } + + local c_args = collection_to_arguments[destination_collection] + local c_list_args = collection_to_list_arguments[destination_collection] + + --object_args_instance -- passed to 'filter' + --list_args_instance -- passed to 'args' + + local object_args_instance, list_args_instance = + separate_args_instance(args_instance, c_args, c_list_args) + + local objs = state.accessor:select(parent, + v.destination_collection, from, + object_args_instance, list_args_instance, extra) + assert(type(objs) == 'table', + 'objs list received from an accessor ' .. + 'must be a table, got ' .. type(objs)) + if c.type == '1:1' or c.type == '1:1*' then + -- we expect here exactly one object even for 1:1* + -- connections because we processed all-parts-are-null + -- situation above + assert(#objs == 1, 'expect one matching object, got ' .. + tostring(#objs)) + return objs[1], destination_type + else -- c.type == '1:N' + return objs, destination_type + end + end + } + return field +end + +--- The function converts passed connection to a field of GraphQL type. +--- +--- @tparam table state for read state.accessor and previously filled +--- state.types (state.types are gql types) +--- @tparam table connection connection to create field on +--- @tparam table collection_name name of the collection which have given +--- connection +--- @treturn table simple and union connection depending on the type of +--- input connection +local convert_connection_to_field = function(state, connection, collection_name) + assert(type(connection.type) == 'string', + 'connection.type must be a string, got ' .. type(connection.type)) + assert(connection.type == '1:1' or connection.type == '1:1*' or + connection.type == '1:N', 'connection.type must be 1:1, 1:1* or 1:N, '.. + 'got ' .. connection.type) + assert(type(connection.name) == 'string', + 'connection.name must be a string, got ' .. type(connection.name)) + assert(connection.destination_collection or connection.variants, + 'connection must either destination_collection or variatns field') + + if connection.destination_collection then + return convert_simple_connection(state, connection, collection_name) + end + + if connection.variants then + return convert_union_connection(state, connection, collection_name) + end +end + --- The function converts passed avro-schema to a GraphQL type. --- --- @tparam table state for read state.accessor and previously filled @@ -282,7 +667,7 @@ gql_type = function(state, avro_schema, collection, collection_name) (collection ~= nil and collection_name ~= nil), ('collection and collection_name must be nils or ' .. 'non-nils simultaneously, got: %s and %s'):format(type(collection), - type(collection_name))) + type(collection_name))) local accessor = state.accessor assert(accessor ~= nil, 'state.accessor must not be nil') @@ -305,143 +690,7 @@ gql_type = function(state, avro_schema, collection, collection_name) -- if collection param is passed then go over all connections for _, c in ipairs((collection or {}).connections or {}) do - assert(type(c.type) == 'string', - 'connection.type must be a string, got ' .. type(c.type)) - assert(c.type == '1:1' or c.type == '1:1*' or c.type == '1:N', - 'connection.type must be 1:1, 1:1* or 1:N, got ' .. c.type) - assert(type(c.name) == 'string', - 'connection.name must be a string, got ' .. type(c.name)) - assert(type(c.destination_collection) == 'string', - 'connection.destination_collection must be a string, got ' .. - type(c.destination_collection)) - assert(type(c.parts) == 'table', - 'connection.parts must be a string, got ' .. type(c.parts)) - - -- gql type of connection field - local destination_type = - state.nullable_collection_types[c.destination_collection] - assert(destination_type ~= nil, - ('destination_type (named %s) must not be nil'):format( - c.destination_collection)) - - local c_args - if c.type == '1:1' then - destination_type = types.nonNull(destination_type) - c_args = state.object_arguments[c.destination_collection] - elseif c.type == '1:1*' then - c_args = state.object_arguments[c.destination_collection] - elseif c.type == '1:N' then - destination_type = types.nonNull(types.list(types.nonNull( - destination_type))) - c_args = state.all_arguments[c.destination_collection] - else - error('unknown connection type: ' .. tostring(c.type)) - end - - local c_list_args = state.list_arguments[c.destination_collection] - - fields[c.name] = { - name = c.name, - kind = destination_type, - arguments = c_args, - resolve = function(parent, args_instance, info) - local destination_args_names = {} - local destination_args_values = {} - local are_all_parts_non_null = true - local are_all_parts_null = true - - for _, part in ipairs(c.parts) do - assert(type(part.source_field) == 'string', - 'part.source_field must be a string, got ' .. - type(part.destination_field)) - assert(type(part.destination_field) == 'string', - 'part.destination_field must be a string, got ' .. - type(part.destination_field)) - - destination_args_names[#destination_args_names + 1] = - part.destination_field - - local value = parent[part.source_field] - destination_args_values[#destination_args_values + 1] = - value - - if value ~= nil then -- nil or box.NULL - are_all_parts_null = false - else - are_all_parts_non_null = false - end - end - - -- Check FULL match constraint before request of - -- destination object(s). Note that connection key parts - -- can be prefix of index key parts. Zero parts count - -- considered as ok by this check. - local ok = are_all_parts_null or are_all_parts_non_null - if not ok then -- avoid extra json.encode() - assert(ok, - 'FULL MATCH constraint was failed: connection ' .. - 'key parts must be all non-nulls or all nulls; ' .. - 'object: ' .. json.encode(parent)) - end - - -- Avoid non-needed index lookup on a destination - -- collection when all connection parts are null: - -- * return null for 1:1* connection; - -- * return {} for 1:N connection (except the case when - -- source collection is the Query pseudo-collection). - if collection_name ~= 'Query' and are_all_parts_null then - if c.type ~= '1:1*' and c.type ~= '1:N' then - -- `if` is to avoid extra json.encode - assert(c.type == '1:1*' or c.type == '1:N', - ('only 1:1* or 1:N connections can have ' .. - 'all key parts null; parent is %s from ' .. - 'collection "%s"'):format(json.encode(parent), - tostring(collection_name))) - end - return c.type == '1:N' and {} or nil - end - - local from = { - collection_name = collection_name, - connection_name = c.name, - destination_args_names = destination_args_names, - destination_args_values = destination_args_values, - } - local extra = { - qcontext = info.qcontext - } - local object_args_instance = {} -- passed to 'filter' - local list_args_instance = {} -- passed to 'args' - for k, v in pairs(args_instance) do - if c_list_args[k] ~= nil then - list_args_instance[k] = v - elseif c_args[k] ~= nil then - object_args_instance[k] = v - else - error(('cannot found "%s" field ("%s" value) ' .. - 'within allowed fields'):format(tostring(k), - tostring(v))) - end - end - local objs = accessor:select(parent, - c.destination_collection, from, - object_args_instance, list_args_instance, extra) - assert(type(objs) == 'table', - 'objs list received from an accessor ' .. - 'must be a table, got ' .. type(objs)) - if c.type == '1:1' or c.type == '1:1*' then - -- we expect here exactly one object even for 1:1* - -- connections because we processed all-parts-are-null - -- situation above - assert(#objs == 1, - 'expect one matching object, got ' .. - tostring(#objs)) - return objs[1] - else -- c.type == '1:N' - return objs - end - end, - } + fields[c.name] = convert_connection_to_field(state, c, collection_name) end -- create gql type diff --git a/graphql/utils.lua b/graphql/utils.lua index 8b8ba8a..681821b 100644 --- a/graphql/utils.lua +++ b/graphql/utils.lua @@ -147,4 +147,27 @@ function utils.optional_require(module_name) return ok and module or nil end +--- @return `table` with all keys of the given table +function utils.get_keys(table) + local keys = {} + for k, _ in pairs(table) do + keys[#keys + 1] = k + end + return keys +end + +--- Check if passed table has passed keys with non-nil values. +--- @tparam table table to check +--- @tparam table keys array of keys to check +--- @return[1] `true` if passed table has passed keys +--- @return[2] `false` otherwise +function utils.do_have_keys(table, keys) + for _, k in pairs(keys) do + if table[k] == nil then + return false + end + end + return true +end + return utils diff --git a/test/local/init_fail.result b/test/local/init_fail.result index ab9b013..8047498 100644 --- a/test/local/init_fail.result +++ b/test/local/init_fail.result @@ -1,2 +1,2 @@ -INIT: ok: false; err: 1:1 connection "user_connection" of collection "user_collection" has less fields than the index "user_str_num_index" has (cannot prove uniqueness of the partial index) +INIT: ok: false; err: 1:1 connection "user_connection" of collection "order_collection" has less fields than the index of "user_str_num_index" collection (cannot prove uniqueness of the partial index) INIT: ok: true; type(res): table diff --git a/test/local/union.result b/test/local/union.result new file mode 100644 index 0000000..10c942a --- /dev/null +++ b/test/local/union.result @@ -0,0 +1,64 @@ +RUN 1 {{{ +QUERY + query obtainHeroes($hero_id: String) { + hero_collection(hero_id: $hero_id) { + hero_id + hero_type + hero_connection { + ... on human_collection { + name + } + ... on starship_collection { + model + } + } + } + } +VARIABLES +--- +hero_id: hero_id_1 +... + +RESULT +--- +hero_collection: +- hero_type: human + hero_connection: + name: Luke + hero_id: hero_id_1 +... + +}}} + +RUN 2 {{{ +QUERY + query obtainHeroes($hero_id: String) { + hero_collection(hero_id: $hero_id) { + hero_id + hero_type + hero_connection { + ... on human_collection { + name + } + ... on starship_collection { + model + } + } + } + } +VARIABLES +--- +hero_id: hero_id_2 +... + +RESULT +--- +hero_collection: +- hero_type: starship + hero_connection: + model: Falcon-42 + hero_id: hero_id_2 +... + +}}} + diff --git a/test/local/union.test.lua b/test/local/union.test.lua new file mode 100755 index 0000000..4652ae0 --- /dev/null +++ b/test/local/union.test.lua @@ -0,0 +1,60 @@ +#!/usr/bin/env tarantool + +box.cfg { background = false } +local fio = require('fio') + +-- require in-repo version of graphql/ sources despite current working directory +package.path = fio.abspath(debug.getinfo(1).source:match("@?(.*/)") + :gsub('/./', '/'):gsub('/+$', '')) .. '/../../?.lua' .. ';' .. package.path + +local graphql = require('graphql') +local testdata = require('test.testdata.union_testdata') + +-- init box, upload test data and acquire metadata +-- ----------------------------------------------- + + +-- init box and data schema +testdata.init_spaces() + +-- upload test data +testdata.fill_test_data() + +-- acquire metadata +local metadata = testdata.get_test_metadata() +local schemas = metadata.schemas +local collections = metadata.collections +local service_fields = metadata.service_fields +local indexes = metadata.indexes +local utils = require('graphql.utils') + +-- build accessor and graphql schemas +-- ---------------------------------- +local accessor = utils.show_trace(function() + return graphql.accessor_space.new({ + schemas = schemas, + collections = collections, + service_fields = service_fields, + indexes = indexes, + }) +end) + +local gql_wrapper = utils.show_trace(function() + return graphql.new({ + schemas = schemas, + collections = collections, + accessor = accessor, + }) +end) + +-- run queries +-- ----------- + +testdata.run_queries(gql_wrapper) + +-- clean up +-- -------- + +testdata.drop_spaces() + +os.exit() \ No newline at end of file diff --git a/test/testdata/compound_index_testdata.lua b/test/testdata/compound_index_testdata.lua index 75ed408..400c166 100644 --- a/test/testdata/compound_index_testdata.lua +++ b/test/testdata/compound_index_testdata.lua @@ -1,9 +1,11 @@ local json = require('json') -local yaml = require('yaml') local utils = require('graphql.utils') +local test_utils = require('test.utils') local compound_index_testdata = {} +local format_result = test_utils.format_result + -- return an error w/o file name and line number local function strip_error(err) return tostring(err):gsub('^.-:.-: (.*)$', '%1') @@ -14,11 +16,6 @@ local function print_and_return(...) return table.concat({...}, ' ') .. '\n' end -local function format_result(name, query, variables, result) - return ('RUN %s {{{\nQUERY\n%s\nVARIABLES\n%s\nRESULT\n%s\n}}}\n'):format( - name, query:rstrip(), yaml.encode(variables), yaml.encode(result)) -end - -- schemas and meta-information -- ---------------------------- diff --git a/test/testdata/union_testdata.lua b/test/testdata/union_testdata.lua new file mode 100644 index 0000000..a98eb24 --- /dev/null +++ b/test/testdata/union_testdata.lua @@ -0,0 +1,216 @@ +local json = require('json') +local utils = require('graphql.utils') +local test_utils = require('test.utils') + +local union_testdata = {} + +function union_testdata.get_test_metadata() + local schemas = json.decode([[{ + "hero": { + "name": "hero", + "type": "record", + "fields": [ + { "name": "hero_id", "type": "string" }, + { "name": "hero_type", "type" : "string" } + ] + }, + "human": { + "name": "human", + "type": "record", + "fields": [ + { "name": "hero_id", "type": "string" }, + { "name": "name", "type": "string" }, + { "name": "episode", "type": "string"} + ] + }, + "starship": { + "name": "starship", + "type": "record", + "fields": [ + { "name": "hero_id", "type": "string" }, + { "name": "model", "type": "string" }, + { "name": "episode", "type": "string"} + ] + } + }]]) + + local collections = json.decode([[{ + "hero_collection": { + "schema_name": "hero", + "connections": [ + { + "name": "hero_connection", + "type": "1:1", + "variants": [ + { + "determinant": {"hero_type": "human"}, + "destination_collection": "human_collection", + "parts": [ + { + "source_field": "hero_id", + "destination_field": "hero_id" + } + ], + "index_name": "human_id_index" + }, + { + "determinant": {"hero_type": "starship"}, + "destination_collection": "starship_collection", + "parts": [ + { + "source_field": "hero_id", + "destination_field": "hero_id" + } + ], + "index_name": "starship_id_index" + } + ] + } + ] + }, + "human_collection": { + "schema_name": "human", + "connections": [] + }, + "starship_collection": { + "schema_name": "starship", + "connections": [] + } + }]]) + + local service_fields = { + hero = { + { name = 'expires_on', type = 'long', default = 0 }, + }, + human = { + { name = 'expires_on', type = 'long', default = 0 }, + }, + starship = { + { name = 'expires_on', type = 'long', default = 0 }, + } + } + + local indexes = { + hero_collection = { + hero_id_index = { + service_fields = {}, + fields = { 'hero_id' }, + index_type = 'tree', + unique = true, + primary = true, + }, + }, + + human_collection = { + human_id_index = { + service_fields = {}, + fields = { 'hero_id' }, + index_type = 'tree', + unique = true, + primary = true, + }, + }, + + starship_collection = { + starship_id_index = { + service_fields = {}, + fields = { 'hero_id' }, + index_type = 'tree', + unique = true, + primary = true, + }, + } + } + + return { + schemas = schemas, + collections = collections, + service_fields = service_fields, + indexes = indexes, + } +end + +function union_testdata.init_spaces() + local ID_FIELD_NUM = 2 + + box.once('test_space_init_spaces', function() + box.schema.create_space('hero_collection') + box.space.hero_collection:create_index('hero_id_index', + { type = 'tree', unique = true, parts = { ID_FIELD_NUM, 'string' }} + ) + + box.schema.create_space('human_collection') + box.space.human_collection:create_index('human_id_index', + { type = 'tree', unique = true, parts = { ID_FIELD_NUM, 'string' }} + ) + + box.schema.create_space('starship_collection') + box.space.starship_collection:create_index('starship_id_index', + { type = 'tree', unique = true, parts = { ID_FIELD_NUM, 'string' }} + ) + end) +end + +function union_testdata.fill_test_data(shard) + local shard = shard or box.space + + shard.hero_collection:replace( + { 1827767717, 'hero_id_1', 'human'}) + shard.hero_collection:replace( + { 1827767717, 'hero_id_2', 'starship'}) + + shard.human_collection:replace( + { 1827767717, 'hero_id_1', 'Luke', "EMPR"}) + + shard.starship_collection:replace( + { 1827767717, 'hero_id_2', 'Falcon-42', "NEW"}) +end + +function union_testdata.drop_spaces() + box.space._schema:delete('oncetest_space_init_spaces') + box.space.human_collection:drop() + box.space.starship_collection:drop() + box.space.hero_collection:drop() +end + +function union_testdata.run_queries(gql_wrapper) + local results = '' + + local query = [[ + query obtainHeroes($hero_id: String) { + hero_collection(hero_id: $hero_id) { + hero_id + hero_type + hero_connection { + ... on human_collection { + name + } + ... on starship_collection { + model + } + } + } + } + ]] + + local gql_query = gql_wrapper:compile(query) + + utils.show_trace(function() + local variables_1 = {hero_id = 'hero_id_1'} + local result = gql_query:execute(variables_1) + results = results .. test_utils.print_and_return(test_utils.format_result( + '1', query, variables_1, result)) + + end) + + utils.show_trace(function() + local variables_2 = {hero_id = 'hero_id_2'} + local result = gql_query:execute(variables_2) + results = results .. test_utils.print_and_return(test_utils.format_result( + '2', query, variables_2, result)) + end) + + return results +end + +return union_testdata diff --git a/test/utils.lua b/test/utils.lua new file mode 100644 index 0000000..b045d4c --- /dev/null +++ b/test/utils.lua @@ -0,0 +1,17 @@ +--- Various utility function used across the graphql module tests. + +local yaml = require('yaml') + +local utils = {} + +function utils.format_result(name, query, variables, result) + return ('RUN %s {{{\nQUERY\n%s\nVARIABLES\n%s\nRESULT\n%s\n}}}\n'):format( + name, query:rstrip(), yaml.encode(variables), yaml.encode(result)) +end + +function utils.print_and_return(...) + print(...) + return table.concat({ ... }, ' ') .. '\n' +end + +return utils