Skip to content

Bugs/directive input object type #384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ A few libraries exist to ease the boilerplate pain, including [GraphQL-Java's bu
<dependency>
<groupId>com.graphql-java-kickstart</groupId>
<artifactId>graphql-java-tools</artifactId>
<version>6.0.0</version>
<version>6.0.2</version>
</dependency>
```
```groovy
compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.0'
compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.2'
```

New releases will be available faster in the JCenter repository than in Maven Central. Add the following to use for Maven
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>com.graphql-java-kickstart</groupId>
<artifactId>graphql-java-tools</artifactId>
<version>6.0.2-SNAPSHOT</version>
<version>6.0.3-SNAPSHOT</version>
<packaging>jar</packaging>

<name>GraphQL Java Tools</name>
Expand Down
86 changes: 66 additions & 20 deletions src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
fun parseSchemaObjects(): SchemaObjects {

// Create GraphQL objects
val interfaces = interfaceDefinitions.map { createInterfaceObject(it) }
val objects = objectDefinitions.map { createObject(it, interfaces) }
// val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())}
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
inputObjectDefinitions.forEach {
if (inputObjects.none { io -> io.name == it.name }) {
inputObjects.add(createInputObject(it, inputObjects))
}
}
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) }
val unions = unionDefinitions.map { createUnionObject(it, objects) }
val inputObjects = inputObjectDefinitions.map { createInputObject(it) }
val enums = enumDefinitions.map { createEnumObject(it) }

// Assign type resolver to interfaces now that we know all of the object types
Expand Down Expand Up @@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
@Suppress("unused")
fun getUnusedDefinitions(): Set<TypeDefinition<*>> = unusedDefinitions

private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List<GraphQLInterfaceType>): GraphQLObjectType {
private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List<GraphQLInterfaceType>, inputObjects: List<GraphQLInputObjectType>): GraphQLObjectType {
val name = objectDefinition.name
val builder = GraphQLObjectType.newObject()
.name(name)
Expand All @@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition ->
fieldDefinition.description
builder.field { field ->
createField(field, fieldDefinition)
createField(field, fieldDefinition, inputObjects)
codeRegistryBuilder.dataFetcher(
FieldCoordinates.coordinates(objectDefinition.name, fieldDefinition.name),
fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher()
Expand Down Expand Up @@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return output.toTypedArray()
}

private fun createInputObject(definition: InputObjectTypeDefinition): GraphQLInputObjectType {
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInputObjectType {
val builder = GraphQLInputObjectType.newInputObject()
.name(definition.name)
.definition(definition)
Expand All @@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
.definition(inputDefinition)
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
.type(determineInputType(inputDefinition.type))
.type(determineInputType(inputDefinition.type, inputObjects))
.withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
builder.field(fieldBuilder.build())
}
Expand Down Expand Up @@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return directiveGenerator.onEnum(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
}

private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition): GraphQLInterfaceType {
private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInterfaceType {
val name = interfaceDefinition.name
val builder = GraphQLInterfaceType.newInterface()
.name(name)
Expand All @@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
builder.withDirectives(*buildDirectives(interfaceDefinition.directives, setOf(), Introspection.DirectiveLocation.INTERFACE))

interfaceDefinition.fieldDefinitions.forEach { fieldDefinition ->
builder.field { field -> createField(field, fieldDefinition) }
builder.field { field -> createField(field, fieldDefinition, inputObjects) }
}

return directiveGenerator.onInterface(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
Expand Down Expand Up @@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return leafObjects
}

private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition): GraphQLFieldDefinition.Builder {
private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLFieldDefinition.Builder {
field.name(fieldDefinition.name)
field.description(if (fieldDefinition.description != null) fieldDefinition.description.content else getDocumentation(fieldDefinition))
field.definition(fieldDefinition)
getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) }
field.type(determineOutputType(fieldDefinition.type))
field.type(determineOutputType(fieldDefinition.type, inputObjects))
fieldDefinition.inputValueDefinitions.forEach { argumentDefinition ->
val argumentBuilder = GraphQLArgument.newArgument()
.name(argumentDefinition.name)
.definition(argumentDefinition)
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
.type(determineInputType(argumentDefinition.type))
.type(determineInputType(argumentDefinition.type, inputObjects))
.withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
field.argument(argumentBuilder.build())
}
Expand All @@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
}
}

private fun determineOutputType(typeDefinition: Type<*>) =
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject) as GraphQLOutputType
private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType

private fun determineInputType(typeDefinition: Type<*>) =
determineType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject) as GraphQLInputType

private fun <T : Any> determineType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>): GraphQLType =
private fun <T : Any> determineType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
when (typeDefinition) {
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is InputObjectTypeDefinition -> {
log.info("Create input object")
createInputObject(typeDefinition, inputObjects)
}
is TypeName -> {
val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
if (scalarType != null) {
Expand All @@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
else -> throw SchemaError("Unknown type: $typeDefinition")
}

private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType

private fun <T : Any> determineInputType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
when (typeDefinition) {
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is InputObjectTypeDefinition -> {
log.info("Create input object")
createInputObject(typeDefinition, inputObjects)
}
is TypeName -> {
val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
if (scalarType != null) {
scalarType
} else {
if (!allowedTypeReferences.contains(typeDefinition.name)) {
throw SchemaError("Expected type '${typeDefinition.name}' to be a ${expectedType.simpleName}, but it wasn't! " +
"Was a type only permitted for object types incorrectly used as an input type, or vice-versa?")
}
val found = inputObjects.filter { it.name == typeDefinition.name }
if (found.size == 1) {
found[0]
} else {
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
if (filteredDefinitions.isNotEmpty()) {
val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
(inputObjects as MutableList).add(inputObject)
inputObject
} else {
// todo: handle enum type
GraphQLTypeReference(typeDefinition.name)
}
}
}
}
else -> throw SchemaError("Unknown type: $typeDefinition")
}

/**
* Returns an optional [String] describing a deprecated field/enum.
* If a deprecation directive was defined using the @deprecated directive,
Expand Down
Loading