Skip to content

Scan directives arguments while parsing schema #764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ internal class SchemaClassScanner(
?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName")
when (typeDefinition) {
is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition)
is InputObjectTypeDefinition -> {
for (input in typeDefinition.inputValueDefinitions) {
handleDirectiveInput(input.type)
}
is EnumTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
"Enum type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
}
is InputObjectTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
"Input object type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
}
}
}
Expand Down Expand Up @@ -209,9 +210,9 @@ internal class SchemaClassScanner(
log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}")
}

val fieldResolvers = fieldResolversByType.flatMap { it.value.map { it.value } }
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<NormalResolverInfo>()
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }
val fieldResolvers = fieldResolversByType.flatMap { entry -> entry.value.map { it.value } }
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<NormalResolverInfo>().toSet()
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }.toSet()

(resolverInfos - observedNormalResolverInfos - observedMultiResolverInfos).forEach { resolverInfo ->
log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}")
Expand Down Expand Up @@ -255,7 +256,7 @@ internal class SchemaClassScanner(
}.flatten().distinct()
}

private fun handleDictionaryTypes(types: List<ObjectTypeDefinition>, failureMessage: (ObjectTypeDefinition) -> String) {
private fun handleDictionaryTypes(types: List<TypeDefinition<*>>, failureMessage: (TypeDefinition<*>) -> String) {
types.forEach { type ->
val dictionaryContainsType = dictionary.filter { it.key.name == type.name }.isNotEmpty()
if (!unvalidatedTypes.contains(type) && !dictionaryContainsType) {
Expand Down
138 changes: 69 additions & 69 deletions src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package graphql.kickstart.tools

import graphql.Scalars
import graphql.introspection.Introspection
import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION
import graphql.kickstart.tools.directive.DirectiveWiringHelper
Expand All @@ -9,6 +8,7 @@ import graphql.kickstart.tools.util.getExtendedFieldDefinitions
import graphql.kickstart.tools.util.unwrap
import graphql.language.*
import graphql.schema.*
import graphql.schema.idl.DirectiveInfo
import graphql.schema.idl.RuntimeWiring
import graphql.schema.idl.ScalarInfo
import graphql.schema.visibility.NoIntrospectionGraphqlFieldVisibility
Expand Down Expand Up @@ -60,6 +60,8 @@ class SchemaParser internal constructor(
private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry()
private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions)

private lateinit var schemaDirectives : Set<GraphQLDirective>

/**
* Parses the given schema with respect to the given dictionary and returns GraphQL objects.
*/
Expand All @@ -72,6 +74,7 @@ class SchemaParser internal constructor(

// Create GraphQL objects
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
createDirectives(inputObjects)
inputObjectDefinitions.forEach {
if (inputObjects.none { io -> io.name == it.name }) {
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
Expand All @@ -82,8 +85,6 @@ class SchemaParser internal constructor(
val unions = unionDefinitions.map { createUnionObject(it, objects) }
val enums = enumDefinitions.map { createEnumObject(it) }

val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet()

// Assign type resolver to interfaces now that we know all of the object types
interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) }
unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) }
Expand All @@ -103,7 +104,7 @@ class SchemaParser internal constructor(
val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation }

val types = (additionalObjects.toSet() as Set<GraphQLType>) + inputObjects + enums + interfaces + unions
return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription())
return SchemaObjects(query, mutation, subscription, types, schemaDirectives, codeRegistryBuilder, rootInfo.getDescription())
}

/**
Expand Down Expand Up @@ -300,44 +301,75 @@ class SchemaParser internal constructor(
.name(definition.name)
.definition(definition)
.description(getDocumentation(definition, options))
.type(determineInputType(definition.type, inputObjects, setOf()))
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
.withAppliedDirectives(*buildAppliedDirectives(definition.directives))
.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
.build()
}

private fun createDirective(definition: DirectiveDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLDirective {
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
private fun createDirectives(inputObjects: MutableList<GraphQLInputObjectType>) {
schemaDirectives = directiveDefinitions.map { definition ->
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()

GraphQLDirective.newDirective()
.name(definition.name)
.description(getDocumentation(definition, options))
.definition(definition)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.validLocations(*locations)
.repeatable(definition.isRepeatable)
.apply {
definition.inputValueDefinitions.forEach { argumentDefinition ->
argument(createDirectiveArgument(argumentDefinition, inputObjects))
}
}
.build()
}.toSet()
// because the arguments can have directives too, we attach them only after the directives themselves are created
schemaDirectives = schemaDirectives.map { d ->
val arguments = d.arguments.map { a -> a.transform {
it.withAppliedDirectives(*buildAppliedDirectives(a.definition!!.directives))
.withDirectives(*buildDirectives(a.definition!!.directives, Introspection.DirectiveLocation.OBJECT))
} }
d.transform { it.replaceArguments(arguments) }
}.toSet()
}

return GraphQLDirective.newDirective()
private fun createDirectiveArgument(definition: InputValueDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLArgument {
return GraphQLArgument.newArgument()
.name(definition.name)
.description(getDocumentation(definition, options))
.definition(definition)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.validLocations(*locations)
.repeatable(definition.isRepeatable)
.apply {
definition.inputValueDefinitions.forEach { argumentDefinition ->
argument(createArgument(argumentDefinition, inputObjects))
}
}
.description(getDocumentation(definition, options))
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
.build()
}

private fun buildAppliedDirectives(directives: List<Directive>): Array<GraphQLAppliedDirective> {
return directives.map {
return directives.map { directive ->
val graphQLDirective = schemaDirectives.find { d -> d.name == directive.name }
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }

GraphQLAppliedDirective.newDirective()
.name(it.name)
.description(getDocumentation(it, options))
.name(directive.name)
.description(getDocumentation(directive, options))
.definition(directive)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.apply {
it.arguments.forEach { arg ->
directive.arguments.forEach { arg ->
val graphQLArgument = graphQLArguments[arg.name]
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name} .")
argument(GraphQLAppliedDirectiveArgument.newArgument()
.name(arg.name)
.type(buildDirectiveInputType(arg.value))
// TODO instead of guessing the type from its value, lookup the directive definition
.type(graphQLArgument.type)
.valueLiteral(arg.value)
.description(graphQLArgument.description)
.build()
)
}
Expand All @@ -358,6 +390,10 @@ class SchemaParser internal constructor(
val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false
if (repeatable || !names.contains(directive.name)) {
names.add(directive.name)
val graphQLDirective = this.schemaDirectives.find { d -> d.name == directive.name }
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
output.add(
GraphQLDirective.newDirective()
.name(directive.name)
Expand All @@ -367,9 +403,11 @@ class SchemaParser internal constructor(
.repeatable(repeatable)
.apply {
directive.arguments.forEach { arg ->
val graphQLArgument = graphQLArguments[arg.name]
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name}.")
argument(GraphQLArgument.newArgument()
.name(arg.name)
.type(buildDirectiveInputType(arg.value))
.type(graphQLArgument.type)
// TODO remove this once directives are fully replaced with applied directives
.valueLiteral(arg.value)
.build())
Expand All @@ -383,46 +421,6 @@ class SchemaParser internal constructor(
return output.toTypedArray()
}

private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? {
return when (value) {
is NullValue -> Scalars.GraphQLString
is FloatValue -> Scalars.GraphQLFloat
is StringValue -> Scalars.GraphQLString
is IntValue -> Scalars.GraphQLInt
is BooleanValue -> Scalars.GraphQLBoolean
is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value)))
// TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?)
else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.")
}
}

private fun getArrayValueWrappedType(value: ArrayValue): Value<*> {
// empty array [] is equivalent to [null]
if (value.values.isEmpty()) {
return NullValue.newNullValue().build()
}

// get rid of null values
val nonNullValueList = value.values.filter { v -> v !is NullValue }

// [null, null, ...] unwrapped is null
if (nonNullValueList.isEmpty()) {
return NullValue.newNullValue().build()
}

// make sure the array isn't polymorphic
val distinctTypes = nonNullValueList
.map { it::class.java }
.distinct()

if (distinctTypes.size > 1) {
throw SchemaError("Arrays containing multiple types of values are not supported yet.")
}

// peek at first value, value exists and is assured to be non-null
return nonNullValueList[0]
}

private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType

Expand Down Expand Up @@ -455,13 +453,15 @@ class SchemaParser internal constructor(
else -> throw SchemaError("Unknown type: $typeDefinition")
}

private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: MutableSet<String>) =
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects)

private fun <T : Any> determineInputType(expectedType: KClass<T>,
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: Set<String>): GraphQLInputType =
private fun <T : Any> determineInputType(
expectedType: KClass<T>,
typeDefinition: Type<*>,
allowedTypeReferences: Set<String>,
inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: MutableSet<String>): GraphQLInputType =
when (typeDefinition) {
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
Expand Down Expand Up @@ -489,7 +489,7 @@ class SchemaParser internal constructor(
if (referencingInputObject != null) {
GraphQLTypeReference(referencingInputObject)
} else {
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects)
(inputObjects as MutableList).add(inputObject)
inputObject
}
Expand Down
Loading