diff --git a/src/main/kotlin/graphql/kickstart/tools/MethodFieldResolver.kt b/src/main/kotlin/graphql/kickstart/tools/MethodFieldResolver.kt index 82fd0b79..54628faa 100644 --- a/src/main/kotlin/graphql/kickstart/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/graphql/kickstart/tools/MethodFieldResolver.kt @@ -92,11 +92,10 @@ internal class MethodFieldResolver( } if (value == null && isOptional) { - if (environment.containsArgument(definition.name)) { - return@add Optional.empty() - } else { + if (options.inputArgumentOptionalDetectOmission && !environment.containsArgument(definition.name)) { return@add null } + return@add Optional.empty() } if (value != null diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt index 8e29112a..c7c9c82b 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt @@ -25,6 +25,7 @@ data class SchemaParserOptions internal constructor( val allowUnimplementedResolvers: Boolean, val objectMapperProvider: PerFieldObjectMapperProvider, val proxyHandlers: List, + val inputArgumentOptionalDetectOmission: Boolean, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean, val coroutineContextProvider: CoroutineContextProvider, @@ -50,6 +51,7 @@ data class SchemaParserOptions internal constructor( private var allowUnimplementedResolvers = false private var objectMapperProvider: PerFieldObjectMapperProvider = PerFieldConfiguringObjectMapperProvider() private val proxyHandlers: MutableList = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler(), WeldProxyHandler()) + private var inputArgumentOptionalDetectOmission = false private var preferGraphQLResolver = false private var introspectionEnabled = true private var coroutineContextProvider: CoroutineContextProvider? = null @@ -80,6 +82,10 @@ data class SchemaParserOptions internal constructor( this.allowUnimplementedResolvers = allowUnimplementedResolvers } + fun inputArgumentOptionalDetectOmission(inputArgumentOptionalDetectOmission: Boolean) = this.apply { + this.inputArgumentOptionalDetectOmission = inputArgumentOptionalDetectOmission + } + fun preferGraphQLResolver(preferGraphQLResolver: Boolean) = this.apply { this.preferGraphQLResolver = preferGraphQLResolver } @@ -146,9 +152,18 @@ data class SchemaParserOptions internal constructor( genericWrappers } - return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, - proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContextProvider, - typeDefinitionFactories, fieldVisibility + return SchemaParserOptions( + contextClass, + wrappers, + allowUnimplementedResolvers, + objectMapperProvider, + proxyHandlers, + inputArgumentOptionalDetectOmission, + preferGraphQLResolver, + introspectionEnabled, + coroutineContextProvider, + typeDefinitionFactories, + fieldVisibility ) } } diff --git a/src/test/kotlin/graphql/kickstart/tools/MethodFieldResolverTest.kt b/src/test/kotlin/graphql/kickstart/tools/MethodFieldResolverTest.kt index 36535eb7..d2d69bd7 100644 --- a/src/test/kotlin/graphql/kickstart/tools/MethodFieldResolverTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/MethodFieldResolverTest.kt @@ -10,11 +10,101 @@ import org.junit.Test import java.lang.reflect.InvocationHandler import java.lang.reflect.Method import java.lang.reflect.Proxy +import java.util.* class MethodFieldResolverTest { @Test - fun shouldHandleScalarTypesAsMethodInputArgument() { + fun `should handle Optional type as method input argument`() { + val schema = SchemaParser.newParser() + .schemaString(""" + type Query { + testValue(input: String): String + testOmitted(input: String): String + testNull(input: String): String + } + """ + ) + .scalars(customScalarType) + .resolvers(object : GraphQLQueryResolver { + fun testValue(input: Optional) = input.toString() + fun testOmitted(input: Optional) = input.toString() + fun testNull(input: Optional) = input.toString() + }) + .build() + .makeExecutableSchema() + + val gql = GraphQL.newGraphQL(schema).build() + + val result = gql + .execute(ExecutionInput.newExecutionInput() + .query(""" + query { + testValue(input: "test-value") + testOmitted + testNull(input: null) + } + """) + .context(Object()) + .root(Object())) + + val expected = mapOf( + "testValue" to "Optional[test-value]", + "testOmitted" to "Optional.empty", + "testNull" to "Optional.empty" + ) + + Assert.assertEquals(expected, result.getData()) + } + + @Test + fun `should handle Optional type as method input argument with omission detection`() { + val schema = SchemaParser.newParser() + .schemaString(""" + type Query { + testValue(input: String): String + testOmitted(input: String): String + testNull(input: String): String + } + """ + ) + .scalars(customScalarType) + .resolvers(object : GraphQLQueryResolver { + fun testValue(input: Optional) = input.toString() + fun testOmitted(input: Optional?) = input.toString() + fun testNull(input: Optional) = input.toString() + }) + .options(SchemaParserOptions.newOptions() + .inputArgumentOptionalDetectOmission(true) + .build()) + .build() + .makeExecutableSchema() + + val gql = GraphQL.newGraphQL(schema).build() + + val result = gql + .execute(ExecutionInput.newExecutionInput() + .query(""" + query { + testValue(input: "test-value") + testOmitted + testNull(input: null) + } + """) + .context(Object()) + .root(Object())) + + val expected = mapOf( + "testValue" to "Optional[test-value]", + "testOmitted" to "null", + "testNull" to "Optional.empty" + ) + + Assert.assertEquals(expected, result.getData()) + } + + @Test + fun `should handle scalar types as method input argument`() { val schema = SchemaParser.newParser() .schemaString(""" scalar CustomScalar @@ -47,7 +137,7 @@ class MethodFieldResolverTest { } @Test - fun shouldHandleListsOfScalarTypes() { + fun `should handle lists of scalar types`() { val schema = SchemaParser.newParser() .schemaString(""" scalar CustomScalar @@ -80,7 +170,7 @@ class MethodFieldResolverTest { } @Test - fun shouldHandleProxies() { + fun `should handle proxies`() { val invocationHandler = object : InvocationHandler { override fun invoke(proxy: Any, method: Method, args: Array): Any { return when (method.name) {