Skip to content

Commit 7b0e758

Browse files
authored
Merge pull request #384 from graphql-java-kickstart/bugs/directive-input-object-type
Bugs/directive input object type
2 parents 31392b1 + 10ee43f commit 7b0e758

File tree

3 files changed

+112
-42
lines changed

3 files changed

+112
-42
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ A few libraries exist to ease the boilerplate pain, including [GraphQL-Java's bu
5555
<dependency>
5656
<groupId>com.graphql-java-kickstart</groupId>
5757
<artifactId>graphql-java-tools</artifactId>
58-
<version>6.0.0</version>
58+
<version>6.0.2</version>
5959
</dependency>
6060
```
6161
```groovy
62-
compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.0'
62+
compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.2'
6363
```
6464

6565
New releases will be available faster in the JCenter repository than in Maven Central. Add the following to use for Maven

src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt

+66-20
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
6464
fun parseSchemaObjects(): SchemaObjects {
6565

6666
// Create GraphQL objects
67-
val interfaces = interfaceDefinitions.map { createInterfaceObject(it) }
68-
val objects = objectDefinitions.map { createObject(it, interfaces) }
67+
// val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())}
68+
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
69+
inputObjectDefinitions.forEach {
70+
if (inputObjects.none { io -> io.name == it.name }) {
71+
inputObjects.add(createInputObject(it, inputObjects))
72+
}
73+
}
74+
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
75+
val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) }
6976
val unions = unionDefinitions.map { createUnionObject(it, objects) }
70-
val inputObjects = inputObjectDefinitions.map { createInputObject(it) }
7177
val enums = enumDefinitions.map { createEnumObject(it) }
7278

7379
// Assign type resolver to interfaces now that we know all of the object types
@@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
103109
@Suppress("unused")
104110
fun getUnusedDefinitions(): Set<TypeDefinition<*>> = unusedDefinitions
105111

106-
private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List<GraphQLInterfaceType>): GraphQLObjectType {
112+
private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List<GraphQLInterfaceType>, inputObjects: List<GraphQLInputObjectType>): GraphQLObjectType {
107113
val name = objectDefinition.name
108114
val builder = GraphQLObjectType.newObject()
109115
.name(name)
@@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
121127
objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition ->
122128
fieldDefinition.description
123129
builder.field { field ->
124-
createField(field, fieldDefinition)
130+
createField(field, fieldDefinition, inputObjects)
125131
codeRegistryBuilder.dataFetcher(
126132
FieldCoordinates.coordinates(objectDefinition.name, fieldDefinition.name),
127133
fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher()
@@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
153159
return output.toTypedArray()
154160
}
155161

156-
private fun createInputObject(definition: InputObjectTypeDefinition): GraphQLInputObjectType {
162+
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInputObjectType {
157163
val builder = GraphQLInputObjectType.newInputObject()
158164
.name(definition.name)
159165
.definition(definition)
@@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
167173
.definition(inputDefinition)
168174
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
169175
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
170-
.type(determineInputType(inputDefinition.type))
176+
.type(determineInputType(inputDefinition.type, inputObjects))
171177
.withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
172178
builder.field(fieldBuilder.build())
173179
}
@@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
210216
return directiveGenerator.onEnum(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
211217
}
212218

213-
private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition): GraphQLInterfaceType {
219+
private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInterfaceType {
214220
val name = interfaceDefinition.name
215221
val builder = GraphQLInterfaceType.newInterface()
216222
.name(name)
@@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
220226
builder.withDirectives(*buildDirectives(interfaceDefinition.directives, setOf(), Introspection.DirectiveLocation.INTERFACE))
221227

222228
interfaceDefinition.fieldDefinitions.forEach { fieldDefinition ->
223-
builder.field { field -> createField(field, fieldDefinition) }
229+
builder.field { field -> createField(field, fieldDefinition, inputObjects) }
224230
}
225231

226232
return directiveGenerator.onInterface(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
@@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
259265
return leafObjects
260266
}
261267

262-
private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition): GraphQLFieldDefinition.Builder {
268+
private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLFieldDefinition.Builder {
263269
field.name(fieldDefinition.name)
264270
field.description(if (fieldDefinition.description != null) fieldDefinition.description.content else getDocumentation(fieldDefinition))
265271
field.definition(fieldDefinition)
266272
getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) }
267-
field.type(determineOutputType(fieldDefinition.type))
273+
field.type(determineOutputType(fieldDefinition.type, inputObjects))
268274
fieldDefinition.inputValueDefinitions.forEach { argumentDefinition ->
269275
val argumentBuilder = GraphQLArgument.newArgument()
270276
.name(argumentDefinition.name)
271277
.definition(argumentDefinition)
272278
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
273279
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
274-
.type(determineInputType(argumentDefinition.type))
280+
.type(determineInputType(argumentDefinition.type, inputObjects))
275281
.withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
276282
field.argument(argumentBuilder.build())
277283
}
@@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
293299
}
294300
}
295301

296-
private fun determineOutputType(typeDefinition: Type<*>) =
297-
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject) as GraphQLOutputType
302+
private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
303+
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType
298304

299-
private fun determineInputType(typeDefinition: Type<*>) =
300-
determineType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject) as GraphQLInputType
301-
302-
private fun <T : Any> determineType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>): GraphQLType =
305+
private fun <T : Any> determineType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
303306
when (typeDefinition) {
304-
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
305-
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
307+
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
308+
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
309+
is InputObjectTypeDefinition -> {
310+
log.info("Create input object")
311+
createInputObject(typeDefinition, inputObjects)
312+
}
306313
is TypeName -> {
307314
val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
308315
if (scalarType != null) {
@@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
318325
else -> throw SchemaError("Unknown type: $typeDefinition")
319326
}
320327

328+
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
329+
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
330+
331+
private fun <T : Any> determineInputType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
332+
when (typeDefinition) {
333+
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
334+
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
335+
is InputObjectTypeDefinition -> {
336+
log.info("Create input object")
337+
createInputObject(typeDefinition, inputObjects)
338+
}
339+
is TypeName -> {
340+
val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
341+
if (scalarType != null) {
342+
scalarType
343+
} else {
344+
if (!allowedTypeReferences.contains(typeDefinition.name)) {
345+
throw SchemaError("Expected type '${typeDefinition.name}' to be a ${expectedType.simpleName}, but it wasn't! " +
346+
"Was a type only permitted for object types incorrectly used as an input type, or vice-versa?")
347+
}
348+
val found = inputObjects.filter { it.name == typeDefinition.name }
349+
if (found.size == 1) {
350+
found[0]
351+
} else {
352+
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
353+
if (filteredDefinitions.isNotEmpty()) {
354+
val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
355+
(inputObjects as MutableList).add(inputObject)
356+
inputObject
357+
} else {
358+
// todo: handle enum type
359+
GraphQLTypeReference(typeDefinition.name)
360+
}
361+
}
362+
}
363+
}
364+
else -> throw SchemaError("Unknown type: $typeDefinition")
365+
}
366+
321367
/**
322368
* Returns an optional [String] describing a deprecated field/enum.
323369
* If a deprecation directive was defined using the @deprecated directive,

0 commit comments

Comments
 (0)