@@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
64
64
fun parseSchemaObjects (): SchemaObjects {
65
65
66
66
// Create GraphQL objects
67
- val interfaces = interfaceDefinitions.map { createInterfaceObject(it) }
68
- val objects = objectDefinitions.map { createObject(it, interfaces) }
67
+ // val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())}
68
+ val inputObjects: MutableList <GraphQLInputObjectType > = mutableListOf ()
69
+ inputObjectDefinitions.forEach {
70
+ if (inputObjects.none { io -> io.name == it.name }) {
71
+ inputObjects.add(createInputObject(it, inputObjects))
72
+ }
73
+ }
74
+ val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
75
+ val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) }
69
76
val unions = unionDefinitions.map { createUnionObject(it, objects) }
70
- val inputObjects = inputObjectDefinitions.map { createInputObject(it) }
71
77
val enums = enumDefinitions.map { createEnumObject(it) }
72
78
73
79
// Assign type resolver to interfaces now that we know all of the object types
@@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
103
109
@Suppress(" unused" )
104
110
fun getUnusedDefinitions (): Set <TypeDefinition <* >> = unusedDefinitions
105
111
106
- private fun createObject (objectDefinition : ObjectTypeDefinition , interfaces : List <GraphQLInterfaceType >): GraphQLObjectType {
112
+ private fun createObject (objectDefinition : ObjectTypeDefinition , interfaces : List <GraphQLInterfaceType >, inputObjects : List < GraphQLInputObjectType > ): GraphQLObjectType {
107
113
val name = objectDefinition.name
108
114
val builder = GraphQLObjectType .newObject()
109
115
.name(name)
@@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
121
127
objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition ->
122
128
fieldDefinition.description
123
129
builder.field { field ->
124
- createField(field, fieldDefinition)
130
+ createField(field, fieldDefinition, inputObjects )
125
131
codeRegistryBuilder.dataFetcher(
126
132
FieldCoordinates .coordinates(objectDefinition.name, fieldDefinition.name),
127
133
fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher()
@@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
153
159
return output.toTypedArray()
154
160
}
155
161
156
- private fun createInputObject (definition : InputObjectTypeDefinition ): GraphQLInputObjectType {
162
+ private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLInputObjectType {
157
163
val builder = GraphQLInputObjectType .newInputObject()
158
164
.name(definition.name)
159
165
.definition(definition)
@@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
167
173
.definition(inputDefinition)
168
174
.description(if (inputDefinition.description != null ) inputDefinition.description.content else getDocumentation(inputDefinition))
169
175
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
170
- .type(determineInputType(inputDefinition.type))
176
+ .type(determineInputType(inputDefinition.type, inputObjects ))
171
177
.withDirectives(* buildDirectives(inputDefinition.directives, setOf (), Introspection .DirectiveLocation .INPUT_FIELD_DEFINITION ))
172
178
builder.field(fieldBuilder.build())
173
179
}
@@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
210
216
return directiveGenerator.onEnum(builder.build(), DirectiveBehavior .Params (runtimeWiring, codeRegistryBuilder))
211
217
}
212
218
213
- private fun createInterfaceObject (interfaceDefinition : InterfaceTypeDefinition ): GraphQLInterfaceType {
219
+ private fun createInterfaceObject (interfaceDefinition : InterfaceTypeDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLInterfaceType {
214
220
val name = interfaceDefinition.name
215
221
val builder = GraphQLInterfaceType .newInterface()
216
222
.name(name)
@@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
220
226
builder.withDirectives(* buildDirectives(interfaceDefinition.directives, setOf (), Introspection .DirectiveLocation .INTERFACE ))
221
227
222
228
interfaceDefinition.fieldDefinitions.forEach { fieldDefinition ->
223
- builder.field { field -> createField(field, fieldDefinition) }
229
+ builder.field { field -> createField(field, fieldDefinition, inputObjects ) }
224
230
}
225
231
226
232
return directiveGenerator.onInterface(builder.build(), DirectiveBehavior .Params (runtimeWiring, codeRegistryBuilder))
@@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
259
265
return leafObjects
260
266
}
261
267
262
- private fun createField (field : GraphQLFieldDefinition .Builder , fieldDefinition : FieldDefinition ): GraphQLFieldDefinition .Builder {
268
+ private fun createField (field : GraphQLFieldDefinition .Builder , fieldDefinition : FieldDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLFieldDefinition .Builder {
263
269
field.name(fieldDefinition.name)
264
270
field.description(if (fieldDefinition.description != null ) fieldDefinition.description.content else getDocumentation(fieldDefinition))
265
271
field.definition(fieldDefinition)
266
272
getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) }
267
- field.type(determineOutputType(fieldDefinition.type))
273
+ field.type(determineOutputType(fieldDefinition.type, inputObjects ))
268
274
fieldDefinition.inputValueDefinitions.forEach { argumentDefinition ->
269
275
val argumentBuilder = GraphQLArgument .newArgument()
270
276
.name(argumentDefinition.name)
271
277
.definition(argumentDefinition)
272
278
.description(if (argumentDefinition.description != null ) argumentDefinition.description.content else getDocumentation(argumentDefinition))
273
279
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
274
- .type(determineInputType(argumentDefinition.type))
280
+ .type(determineInputType(argumentDefinition.type, inputObjects ))
275
281
.withDirectives(* buildDirectives(argumentDefinition.directives, setOf (), Introspection .DirectiveLocation .ARGUMENT_DEFINITION ))
276
282
field.argument(argumentBuilder.build())
277
283
}
@@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
293
299
}
294
300
}
295
301
296
- private fun determineOutputType (typeDefinition : Type <* >) =
297
- determineType(GraphQLOutputType ::class , typeDefinition, permittedTypesForObject) as GraphQLOutputType
302
+ private fun determineOutputType (typeDefinition : Type <* >, inputObjects : List < GraphQLInputObjectType > ) =
303
+ determineType(GraphQLOutputType ::class , typeDefinition, permittedTypesForObject, inputObjects ) as GraphQLOutputType
298
304
299
- private fun determineInputType (typeDefinition : Type <* >) =
300
- determineType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject) as GraphQLInputType
301
-
302
- private fun <T : Any > determineType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >): GraphQLType =
305
+ private fun <T : Any > determineType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
303
306
when (typeDefinition) {
304
- is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences))
305
- is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences))
307
+ is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
308
+ is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
309
+ is InputObjectTypeDefinition -> {
310
+ log.info(" Create input object" )
311
+ createInputObject(typeDefinition, inputObjects)
312
+ }
306
313
is TypeName -> {
307
314
val scalarType = customScalars[typeDefinition.name] ? : graphQLScalars[typeDefinition.name]
308
315
if (scalarType != null ) {
@@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
318
325
else -> throw SchemaError (" Unknown type: $typeDefinition " )
319
326
}
320
327
328
+ private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >) =
329
+ determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
330
+
331
+ private fun <T : Any > determineInputType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
332
+ when (typeDefinition) {
333
+ is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
334
+ is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
335
+ is InputObjectTypeDefinition -> {
336
+ log.info(" Create input object" )
337
+ createInputObject(typeDefinition, inputObjects)
338
+ }
339
+ is TypeName -> {
340
+ val scalarType = customScalars[typeDefinition.name] ? : graphQLScalars[typeDefinition.name]
341
+ if (scalarType != null ) {
342
+ scalarType
343
+ } else {
344
+ if (! allowedTypeReferences.contains(typeDefinition.name)) {
345
+ throw SchemaError (" Expected type '${typeDefinition.name} ' to be a ${expectedType.simpleName} , but it wasn't! " +
346
+ " Was a type only permitted for object types incorrectly used as an input type, or vice-versa?" )
347
+ }
348
+ val found = inputObjects.filter { it.name == typeDefinition.name }
349
+ if (found.size == 1 ) {
350
+ found[0 ]
351
+ } else {
352
+ val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
353
+ if (filteredDefinitions.isNotEmpty()) {
354
+ val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects)
355
+ (inputObjects as MutableList ).add(inputObject)
356
+ inputObject
357
+ } else {
358
+ // todo: handle enum type
359
+ GraphQLTypeReference (typeDefinition.name)
360
+ }
361
+ }
362
+ }
363
+ }
364
+ else -> throw SchemaError (" Unknown type: $typeDefinition " )
365
+ }
366
+
321
367
/* *
322
368
* Returns an optional [String] describing a deprecated field/enum.
323
369
* If a deprecation directive was defined using the @deprecated directive,
0 commit comments