diff --git a/README.md b/README.md index c9bddb66..2206279c 100644 --- a/README.md +++ b/README.md @@ -55,11 +55,11 @@ A few libraries exist to ease the boilerplate pain, including [GraphQL-Java's bu com.graphql-java-kickstart graphql-java-tools - 6.0.0 + 6.0.2 ``` ```groovy -compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.0' +compile 'com.graphql-java-kickstart:graphql-java-tools:6.0.2' ``` New releases will be available faster in the JCenter repository than in Maven Central. Add the following to use for Maven diff --git a/pom.xml b/pom.xml index 042ce41d..56dcae9b 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ com.graphql-java-kickstart graphql-java-tools - 6.0.2-SNAPSHOT + 6.0.3-SNAPSHOT jar GraphQL Java Tools diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index 42bde120..8dba70ce 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat fun parseSchemaObjects(): SchemaObjects { // Create GraphQL objects - val interfaces = interfaceDefinitions.map { createInterfaceObject(it) } - val objects = objectDefinitions.map { createObject(it, interfaces) } +// val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())} + val inputObjects: MutableList = mutableListOf() + inputObjectDefinitions.forEach { + if (inputObjects.none { io -> io.name == it.name }) { + inputObjects.add(createInputObject(it, inputObjects)) + } + } + val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) } + val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) } val unions = unionDefinitions.map { createUnionObject(it, objects) } - val inputObjects = inputObjectDefinitions.map { createInputObject(it) } val enums = enumDefinitions.map { createEnumObject(it) } // Assign type resolver to interfaces now that we know all of the object types @@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat @Suppress("unused") fun getUnusedDefinitions(): Set> = unusedDefinitions - private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List): GraphQLObjectType { + private fun createObject(objectDefinition: ObjectTypeDefinition, interfaces: List, inputObjects: List): GraphQLObjectType { val name = objectDefinition.name val builder = GraphQLObjectType.newObject() .name(name) @@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition -> fieldDefinition.description builder.field { field -> - createField(field, fieldDefinition) + createField(field, fieldDefinition, inputObjects) codeRegistryBuilder.dataFetcher( FieldCoordinates.coordinates(objectDefinition.name, fieldDefinition.name), fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher() @@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat return output.toTypedArray() } - private fun createInputObject(definition: InputObjectTypeDefinition): GraphQLInputObjectType { + private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List): GraphQLInputObjectType { val builder = GraphQLInputObjectType.newInputObject() .name(definition.name) .definition(definition) @@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat .definition(inputDefinition) .description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition)) .defaultValue(buildDefaultValue(inputDefinition.defaultValue)) - .type(determineInputType(inputDefinition.type)) + .type(determineInputType(inputDefinition.type, inputObjects)) .withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION)) builder.field(fieldBuilder.build()) } @@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat return directiveGenerator.onEnum(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder)) } - private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition): GraphQLInterfaceType { + private fun createInterfaceObject(interfaceDefinition: InterfaceTypeDefinition, inputObjects: List): GraphQLInterfaceType { val name = interfaceDefinition.name val builder = GraphQLInterfaceType.newInterface() .name(name) @@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat builder.withDirectives(*buildDirectives(interfaceDefinition.directives, setOf(), Introspection.DirectiveLocation.INTERFACE)) interfaceDefinition.fieldDefinitions.forEach { fieldDefinition -> - builder.field { field -> createField(field, fieldDefinition) } + builder.field { field -> createField(field, fieldDefinition, inputObjects) } } return directiveGenerator.onInterface(builder.build(), DirectiveBehavior.Params(runtimeWiring, codeRegistryBuilder)) @@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat return leafObjects } - private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition): GraphQLFieldDefinition.Builder { + private fun createField(field: GraphQLFieldDefinition.Builder, fieldDefinition: FieldDefinition, inputObjects: List): GraphQLFieldDefinition.Builder { field.name(fieldDefinition.name) field.description(if (fieldDefinition.description != null) fieldDefinition.description.content else getDocumentation(fieldDefinition)) field.definition(fieldDefinition) getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) } - field.type(determineOutputType(fieldDefinition.type)) + field.type(determineOutputType(fieldDefinition.type, inputObjects)) fieldDefinition.inputValueDefinitions.forEach { argumentDefinition -> val argumentBuilder = GraphQLArgument.newArgument() .name(argumentDefinition.name) .definition(argumentDefinition) .description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition)) .defaultValue(buildDefaultValue(argumentDefinition.defaultValue)) - .type(determineInputType(argumentDefinition.type)) + .type(determineInputType(argumentDefinition.type, inputObjects)) .withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) field.argument(argumentBuilder.build()) } @@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat } } - private fun determineOutputType(typeDefinition: Type<*>) = - determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject) as GraphQLOutputType + private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List) = + determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType - private fun determineInputType(typeDefinition: Type<*>) = - determineType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject) as GraphQLInputType - - private fun determineType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set): GraphQLType = + private fun determineType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType = when (typeDefinition) { - is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences)) - is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences)) + is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) + is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) + is InputObjectTypeDefinition -> { + log.info("Create input object") + createInputObject(typeDefinition, inputObjects) + } is TypeName -> { val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name] if (scalarType != null) { @@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat else -> throw SchemaError("Unknown type: $typeDefinition") } + private fun determineInputType(typeDefinition: Type<*>, inputObjects: List) = + determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType + + private fun determineInputType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType = + when (typeDefinition) { + is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) + is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) + is InputObjectTypeDefinition -> { + log.info("Create input object") + createInputObject(typeDefinition, inputObjects) + } + is TypeName -> { + val scalarType = customScalars[typeDefinition.name] ?: graphQLScalars[typeDefinition.name] + if (scalarType != null) { + scalarType + } else { + if (!allowedTypeReferences.contains(typeDefinition.name)) { + throw SchemaError("Expected type '${typeDefinition.name}' to be a ${expectedType.simpleName}, but it wasn't! " + + "Was a type only permitted for object types incorrectly used as an input type, or vice-versa?") + } + val found = inputObjects.filter { it.name == typeDefinition.name } + if (found.size == 1) { + found[0] + } else { + val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name } + if (filteredDefinitions.isNotEmpty()) { + val inputObject = createInputObject(filteredDefinitions[0], inputObjects) + (inputObjects as MutableList).add(inputObject) + inputObject + } else { + // todo: handle enum type + GraphQLTypeReference(typeDefinition.name) + } + } + } + } + else -> throw SchemaError("Unknown type: $typeDefinition") + } + /** * Returns an optional [String] describing a deprecated field/enum. * If a deprecation directive was defined using the @deprecated directive, diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java index e064d8c0..0758688b 100644 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java +++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java @@ -34,13 +34,14 @@ import static java.util.stream.Collectors.toList; /** - * This contains the helper code that allows {@link graphql.schema.idl.SchemaDirectiveWiring} implementations - * to be invoked during schema generation. + * This contains the helper code that allows {@link graphql.schema.idl.SchemaDirectiveWiring} implementations to be + * invoked during schema generation. */ @Internal public class SchemaGeneratorDirectiveHelper { public static class Parameters { + private final TypeDefinitionRegistry typeRegistry; private final RuntimeWiring runtimeWiring; private final NodeParentTree nodeParentTree; @@ -50,11 +51,15 @@ public static class Parameters { private final GraphQLFieldsContainer fieldsContainer; private final GraphQLFieldDefinition fieldDefinition; - public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry) { + public Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, + GraphQLCodeRegistry.Builder codeRegistry) { this(typeRegistry, runtimeWiring, context, codeRegistry, null, null, null, null); } - Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, GraphQLCodeRegistry.Builder codeRegistry, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree, GraphQLFieldsContainer fieldsContainer, GraphQLFieldDefinition fieldDefinition) { + Parameters(TypeDefinitionRegistry typeRegistry, RuntimeWiring runtimeWiring, Map context, + GraphQLCodeRegistry.Builder codeRegistry, NodeParentTree nodeParentTree, + GraphqlElementParentTree elementParentTree, GraphQLFieldsContainer fieldsContainer, + GraphQLFieldDefinition fieldDefinition) { this.typeRegistry = typeRegistry; this.runtimeWiring = runtimeWiring; this.nodeParentTree = nodeParentTree; @@ -97,16 +102,21 @@ public GraphQLFieldDefinition getFieldsDefinition() { return fieldDefinition; } - public Parameters newParams(GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition); + public Parameters newParams(GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree, + GraphqlElementParentTree elementParentTree) { + return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, + elementParentTree, fieldsContainer, fieldDefinition); } - public Parameters newParams(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, fieldsContainer, fieldDefinition); + public Parameters newParams(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, + NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) { + return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, + elementParentTree, fieldsContainer, fieldDefinition); } public Parameters newParams(NodeParentTree nodeParentTree, GraphqlElementParentTree elementParentTree) { - return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, elementParentTree, this.fieldsContainer, fieldDefinition); + return new Parameters(this.typeRegistry, this.runtimeWiring, this.context, this.codeRegistry, nodeParentTree, + elementParentTree, this.fieldsContainer, fieldDefinition); } } @@ -126,10 +136,13 @@ private GraphqlElementParentTree buildRuntimeTree(GraphQLSchemaElement... elemen return new GraphqlElementParentTree(nodeStack); } - private List wireArguments(GraphQLFieldDefinition fieldDefinition, GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params, GraphQLFieldDefinition field) { + private List wireArguments(GraphQLFieldDefinition fieldDefinition, + GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params, + GraphQLFieldDefinition field) { return field.getArguments().stream().map(argument -> { - NodeParentTree nodeParentTree = buildAstTree(fieldsContainerNode, field.getDefinition(), argument.getDefinition()); + NodeParentTree nodeParentTree = buildAstTree(fieldsContainerNode, field.getDefinition(), + argument.getDefinition()); GraphqlElementParentTree elementParentTree = buildRuntimeTree(fieldsContainer, field, argument); Parameters argParams = params.newParams(fieldDefinition, fieldsContainer, nodeParentTree, elementParentTree); @@ -138,11 +151,13 @@ private List wireArguments(GraphQLFieldDefinition fieldDefiniti }).collect(toList()); } - private List wireFields(GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, Parameters params) { + private List wireFields(GraphQLFieldsContainer fieldsContainer, NamedNode fieldsContainerNode, + Parameters params) { return fieldsContainer.getFieldDefinitions().stream().map(fieldDefinition -> { // and for each argument in the fieldDefinition run the wiring for them - and note that they can change - List newArgs = wireArguments(fieldDefinition, fieldsContainer, fieldsContainerNode, params, fieldDefinition); + List newArgs = wireArguments(fieldDefinition, fieldsContainer, fieldsContainerNode, params, + fieldDefinition); // they may have changed the arguments to the fieldDefinition so reflect that fieldDefinition = fieldDefinition.transform(builder -> builder.clearArguments().arguments(newArgs)); @@ -189,7 +204,8 @@ public GraphQLEnumType onEnum(GraphQLEnumType enumType, Parameters params) { List newEnums = enumType.getValues().stream().map(enumValueDefinition -> { - NodeParentTree nodeParentTree = buildAstTree(enumType.getDefinition(), enumValueDefinition.getDefinition()); + NodeParentTree nodeParentTree = buildAstTree(enumType.getDefinition(), + enumValueDefinition.getDefinition()); GraphqlElementParentTree elementParentTree = buildRuntimeTree(enumType, enumValueDefinition); Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree); @@ -211,7 +227,8 @@ public GraphQLEnumType onEnum(GraphQLEnumType enumType, Parameters params) { public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObjectType, Parameters params) { List newFields = inputObjectType.getFieldDefinitions().stream().map(inputField -> { - NodeParentTree nodeParentTree = buildAstTree(inputObjectType.getDefinition(), inputField.getDefinition()); + NodeParentTree nodeParentTree = buildAstTree(inputObjectType.getDefinition(), + inputField.getDefinition()); GraphqlElementParentTree elementParentTree = buildRuntimeTree(inputObjectType, inputField); Parameters fieldParams = params.newParams(nodeParentTree, elementParentTree); @@ -219,7 +236,8 @@ public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObje return onInputObjectField(inputField, fieldParams); }).collect(toList()); - GraphQLInputObjectType newInputObjectType = inputObjectType.transform(builder -> builder.clearFields().fields(newFields)); + GraphQLInputObjectType newInputObjectType = inputObjectType + .transform(builder -> builder.clearFields().fields(newFields)); NodeParentTree nodeParentTree = buildAstTree(newInputObjectType.getDefinition()); GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInputObjectType); @@ -280,13 +298,16 @@ private GraphQLArgument onArgument(GraphQLArgument argument, Parameters params) // builds a type safe SchemaDirectiveWiringEnvironment // interface EnvBuilder { - SchemaDirectiveWiringEnvironment apply(T outputElement, List allDirectives, GraphQLDirective registeredDirective); + + SchemaDirectiveWiringEnvironment apply(T outputElement, List allDirectives, + GraphQLDirective registeredDirective); } // // invokes the SchemaDirectiveWiring with the provided environment // interface EnvInvoker { + T apply(SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env); } @@ -322,16 +343,19 @@ private T wireDirectives( // wiring factory is last (if present) env = envBuilder.apply(outputObject, allDirectives, null); if (wiringFactory.providesSchemaDirectiveWiring(env)) { - schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), "Your WiringFactory MUST provide a non null SchemaDirectiveWiring"); + schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), + "Your WiringFactory MUST provide a non null SchemaDirectiveWiring"); outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env); } return outputObject; } - private T invokeWiring(T element, EnvInvoker invoker, SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env) { + private T invokeWiring(T element, EnvInvoker invoker, + SchemaDirectiveWiring schemaDirectiveWiring, SchemaDirectiveWiringEnvironment env) { T newElement = invoker.apply(schemaDirectiveWiring, env); - assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.getName() + "'"); + assertNotNull(newElement, + "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.getName() + "'"); return newElement; } }