diff --git a/README.md b/README.md
index c9bddb66..2206279c 100644
--- a/README.md
+++ b/README.md
@@ -55,11 +55,11 @@ A few libraries exist to ease the boilerplate pain, including [GraphQL-Java's bu
com.graphql-java-kickstart
graphql-java-tools
- 6.0.0
+ 6.0.2
```
```groovy
-compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.0'
+compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.2'
```
New releases will be available faster in the JCenter repository than in Maven Central. Add the following to use for Maven
diff --git a/pom.xml b/pom.xml
index 042ce41d..56dcae9b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
com.graphql-java-kickstart
graphql-java-tools
- 6.0.2-SNAPSHOT
+ 6.0.3-SNAPSHOT
jar
GraphQL Java Tools
diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
index 42bde120..8dba70ce 100644
--- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
+++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
@@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
fun parseSchemaObjects(): SchemaObjects {
// Create GraphQL objects
- val interfaces = interfaceDefinitions.map { createInterfaceObject(it) }
- val objects = objectDefinitions.map { createObject(it, interfaces) }
+// val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())}
+ val inputObjects: MutableList = mutableListOf()
+ inputObjectDefinitions.forEach {
+ if (inputObjects.none { io -> io.name == it.name }) {
+ inputObjects.add(createInputObject(it, inputObjects))
+ }
+ }
+ val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
+ val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) }
val unions = unionDefinitions.map { createUnionObject(it, objects) }
- val inputObjects = inputObjectDefinitions.map { createInputObject(it) }
val enums = enumDefinitions.map { createEnumObject(it) }
// Assign type resolver to interfaces now that we know all of the object types
@@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
@Suppress("unused")
fun getUnusedDefinitions(): Set> = unusedDefinitions
- private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List): GraphQLObjectType {
+ private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List, inputObjects: List): GraphQLObjectType {
val name = objectDefinition.name
val builder = GraphQLObjectType.newObject()
.name(name)
@@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition ->
fieldDefinition.description
builder.field { field ->
- createField(field, fieldDefinition)
+ createField(field, fieldDefinition, inputObjects)
codeRegistryBuilder.dataFetcher(
FieldCoordinates.coordinates(objectDefinition.name, fieldDefinition.name),
fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher()
@@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return output.toTypedArray()
}
- private fun createInputObject(definition: InputObjectTypeDefinition): GraphQLInputObjectType {
+ private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List): GraphQLInputObjectType {
val builder = GraphQLInputObjectType.newInputObject()
.name(definition.name)
.definition(definition)
@@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
.definition(inputDefinition)
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
- .type(determineInputType(inputDefinition.type))
+ .type(determineInputType(inputDefinition.type, inputObjects))
.withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
builder.field(fieldBuilder.build())
}
@@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return directiveGenerator.onEnum(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
}
- private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition): GraphQLInterfaceType {
+ private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition, inputObjects: List): GraphQLInterfaceType {
val name = interfaceDefinition.name
val builder = GraphQLInterfaceType.newInterface()
.name(name)
@@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
builder.withDirectives(*buildDirectives(interfaceDefinition.directives, setOf(), Introspection.DirectiveLocation.INTERFACE))
interfaceDefinition.fieldDefinitions.forEach { fieldDefinition ->
- builder.field { field -> createField(field, fieldDefinition) }
+ builder.field { field -> createField(field, fieldDefinition, inputObjects) }
}
return directiveGenerator.onInterface(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder))
@@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
return leafObjects
}
- private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition): GraphQLFieldDefinition.Builder {
+ private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition, inputObjects: List): GraphQLFieldDefinition.Builder {
field.name(fieldDefinition.name)
field.description(if (fieldDefinition.description != null) fieldDefinition.description.content else getDocumentation(fieldDefinition))
field.definition(fieldDefinition)
getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) }
- field.type(determineOutputType(fieldDefinition.type))
+ field.type(determineOutputType(fieldDefinition.type, inputObjects))
fieldDefinition.inputValueDefinitions.forEach { argumentDefinition ->
val argumentBuilder = GraphQLArgument.newArgument()
.name(argumentDefinition.name)
.definition(argumentDefinition)
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
- .type(determineInputType(argumentDefinition.type))
+ .type(determineInputType(argumentDefinition.type, inputObjects))
.withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
field.argument(argumentBuilder.build())
}
@@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
}
}
- private fun determineOutputType(typeDefinition: Type<*>) =
- determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject) as GraphQLOutputType
+ private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List) =
+ determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType
- private fun determineInputType(typeDefinition: Type<*>) =
- determineType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject) as GraphQLInputType
-
- private fun determineType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set): GraphQLType =
+ private fun determineType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType =
when (typeDefinition) {
- is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
- is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences))
+ is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
+ is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
+ is InputObjectTypeDefinition -> {
+ log.info("Create input object")
+ createInputObject(typeDefinition, inputObjects)
+ }
is TypeName -> {
val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
if (scalarType != null) {
@@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
else -> throw SchemaError("Unknown type: $typeDefinition")
}
+ private fun determineInputType(typeDefinition: Type<*>, inputObjects: List) =
+ determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
+
+ private fun determineInputType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType =
+ when (typeDefinition) {
+ is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
+ is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
+ is InputObjectTypeDefinition -> {
+ log.info("Create input object")
+ createInputObject(typeDefinition, inputObjects)
+ }
+ is TypeName -> {
+ val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name]
+ if (scalarType != null) {
+ scalarType
+ } else {
+ if (!allowedTypeReferences.contains(typeDefinition.name)) {
+ throw SchemaError("Expected type '${typeDefinition.name}' to be a ${expectedType.simpleName}, but it wasn't! " +
+ "Was a type only permitted for object types incorrectly used as an input type, or vice-versa?")
+ }
+ val found = inputObjects.filter { it.name == typeDefinition.name }
+ if (found.size == 1) {
+ found[0]
+ } else {
+ val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
+ if (filteredDefinitions.isNotEmpty()) {
+ val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
+ (inputObjects as MutableList).add(inputObject)
+ inputObject
+ } else {
+ // todo: handle enum type
+ GraphQLTypeReference(typeDefinition.name)
+ }
+ }
+ }
+ }
+ else -> throw SchemaError("Unknown type: $typeDefinition")
+ }
+
/**
* Returns an optional [String] describing a deprecated field/enum.
* If a deprecation directive was defined using the @deprecated directive,
diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java
index e064d8c0..0758688b 100644
--- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java
+++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java
@@ -34,13 +34,14 @@
import static java.util.stream.Collectors.toList;
/**
- * This contains the helper code that allows {@link graphql.schema.idl.SchemaDirectiveWiring} implementations
- * to be invoked during schema generation.
+ * This contains the helper code that allows {@link graphql.schema.idl.SchemaDirectiveWiring} implementations to be
+ * invoked during schema generation.
*/
@Internal
public class SchemaGeneratorDirectiveHelper {
public static class Parameters {
+
private final TypeDefinitionRegistry typeRegistry;
private final RuntimeWiring runtimeWiring;
private final NodeParentTree nodeParentTree;
@@ -50,11 +51,15 @@ public static class Parameters {
private final GraphQLFieldsContainer fieldsContainer;
private final GraphQLFieldDefinition fieldDefinition;
- public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry) {
+ public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context,
+ GraphQLCodeRegistry.Builder codeRegistry) {
this(typeRegistry, runtimeWiring, context, codeRegistry, null, null, null, null);
}
- Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree, GraphQLFieldsContainer fieldsContainer, GraphQLFieldDefinition fieldDefinition) {
+ Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context,
+ GraphQLCodeRegistry.Builder codeRegistry, NodeParentTree nodeParentTree,
+ GraphqlElementParentTree elementParentTree, GraphQLFieldsContainer fieldsContainer,
+ GraphQLFieldDefinition fieldDefinition) {
this.typeRegistry = typeRegistry;
this.runtimeWiring = runtimeWiring;
this.nodeParentTree = nodeParentTree;
@@ -97,16 +102,21 @@ public GraphQLFieldDefinition getFieldsDefinition() {
return fieldDefinition;
}
- public Parameters newParams(GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) {
- return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition);
+ public Parameters newParams(GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree,
+ GraphqlElementParentTree elementParentTree) {
+ return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree,
+ elementParentTree, fieldsContainer, fieldDefinition);
}
- public Parameters newParams(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) {
- return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition);
+ public Parameters newParams(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer,
+ NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) {
+ return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree,
+ elementParentTree, fieldsContainer, fieldDefinition);
}
public Parameters newParams(NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) {
- return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, this.fieldsContainer, fieldDefinition);
+ return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree,
+ elementParentTree, this.fieldsContainer, fieldDefinition);
}
}
@@ -126,10 +136,13 @@ private GraphqlElementParentTree buildRuntimeTree(GraphQLSchemaElement... elemen
return new GraphqlElementParentTree(nodeStack);
}
- private List wireArguments(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params, GraphQLFieldDefinition field) {
+ private List wireArguments(GraphQLFieldDefinition fieldDefinition,
+ GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params,
+ GraphQLFieldDefinition field) {
return field.getArguments().stream().map(argument -> {
- NodeParentTree nodeParentTree = buildAstTree(fieldsContainerNode, field.getDefinition(), argument.getDefinition());
+ NodeParentTree nodeParentTree = buildAstTree(fieldsContainerNode, field.getDefinition(),
+ argument.getDefinition());
GraphqlElementParentTree elementParentTree = buildRuntimeTree(fieldsContainer, field, argument);
Parameters argParams = params.newParams(fieldDefinition, fieldsContainer, nodeParentTree, elementParentTree);
@@ -138,11 +151,13 @@ private List wireArguments(GraphQLFieldDefinition fieldDefiniti
}).collect(toList());
}
- private List wireFields(GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params) {
+ private List wireFields(GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode,
+ Parameters params) {
return fieldsContainer.getFieldDefinitions().stream().map(fieldDefinition -> {
// and for each argument in the fieldDefinition run the wiring for them - and note that they can change
- List newArgs = wireArguments(fieldDefinition, fieldsContainer, fieldsContainerNode, params, fieldDefinition);
+ List newArgs = wireArguments(fieldDefinition, fieldsContainer, fieldsContainerNode, params,
+ fieldDefinition);
// they may have changed the arguments to the fieldDefinition so reflect that
fieldDefinition = fieldDefinition.transform(builder -> builder.clearArguments().arguments(newArgs));
@@ -189,7 +204,8 @@ public GraphQLEnumType onEnum(GraphQLEnumType enumType, Parameters params) {
List newEnums = enumType.getValues().stream().map(enumValueDefinition -> {
- NodeParentTree nodeParentTree = buildAstTree(enumType.getDefinition(), enumValueDefinition.getDefinition());
+ NodeParentTree nodeParentTree = buildAstTree(enumType.getDefinition(),
+ enumValueDefinition.getDefinition());
GraphqlElementParentTree elementParentTree = buildRuntimeTree(enumType, enumValueDefinition);
Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree);
@@ -211,7 +227,8 @@ public GraphQLEnumType onEnum(GraphQLEnumType enumType, Parameters params) {
public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObjectType, Parameters params) {
List newFields = inputObjectType.getFieldDefinitions().stream().map(inputField -> {
- NodeParentTree nodeParentTree = buildAstTree(inputObjectType.getDefinition(), inputField.getDefinition());
+ NodeParentTree nodeParentTree = buildAstTree(inputObjectType.getDefinition(),
+ inputField.getDefinition());
GraphqlElementParentTree elementParentTree = buildRuntimeTree(inputObjectType, inputField);
Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree);
@@ -219,7 +236,8 @@ public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObje
return onInputObjectField(inputField, fieldParams);
}).collect(toList());
- GraphQLInputObjectType newInputObjectType = inputObjectType.transform(builder -> builder.clearFields().fields(newFields));
+ GraphQLInputObjectType newInputObjectType = inputObjectType
+ .transform(builder -> builder.clearFields().fields(newFields));
NodeParentTree nodeParentTree = buildAstTree(newInputObjectType.getDefinition());
GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInputObjectType);
@@ -280,13 +298,16 @@ private GraphQLArgument onArgument(GraphQLArgument argument, Parameters params)
// builds a type safe SchemaDirectiveWiringEnvironment
//
interface EnvBuilder {
- SchemaDirectiveWiringEnvironment apply(T outputElement, List allDirectives, GraphQLDirective registeredDirective);
+
+ SchemaDirectiveWiringEnvironment apply(T outputElement, List allDirectives,
+ GraphQLDirective registeredDirective);
}
//
// invokes the SchemaDirectiveWiring with the provided environment
//
interface EnvInvoker {
+
T apply(SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env);
}
@@ -322,16 +343,19 @@ private T wireDirectives(
// wiring factory is last (if present)
env = envBuilder.apply(outputObject, allDirectives, null);
if (wiringFactory.providesSchemaDirectiveWiring(env)) {
- schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), "Your WiringFactory MUST provide a non null SchemaDirectiveWiring");
+ schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env),
+ "Your WiringFactory MUST provide a non null SchemaDirectiveWiring");
outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env);
}
return outputObject;
}
- private T invokeWiring(T element, EnvInvoker invoker, SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env) {
+ private T invokeWiring(T element, EnvInvoker invoker,
+ SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env) {
T newElement = invoker.apply(schemaDirectiveWiring, env);
- assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.getName() + "'");
+ assertNotNull(newElement,
+ "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.getName() + "'");
return newElement;
}
}