diff --git a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java index cf8bb041..ca4b06b4 100644 --- a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java @@ -3,6 +3,8 @@ import com.google.common.io.ByteStreams; import com.google.common.io.CharStreams; import graphql.ExecutionResult; +import graphql.GraphQL; +import graphql.execution.reactive.SingleSubscriberPublisher; import graphql.introspection.IntrospectionQuery; import graphql.schema.GraphQLFieldDefinition; import graphql.servlet.config.GraphQLConfiguration; @@ -13,11 +15,7 @@ import graphql.servlet.core.GraphQLServletListener; import graphql.servlet.core.internal.GraphQLRequest; import graphql.servlet.core.internal.VariableMapper; -import graphql.servlet.input.BatchInputPreProcessResult; -import graphql.servlet.input.BatchInputPreProcessor; -import graphql.servlet.input.GraphQLBatchedInvocationInput; -import graphql.servlet.input.GraphQLInvocationInputFactory; -import graphql.servlet.input.GraphQLSingleInvocationInput; +import graphql.servlet.input.*; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -28,24 +26,12 @@ import javax.servlet.AsyncEvent; import javax.servlet.AsyncListener; import javax.servlet.Servlet; -import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.Part; -import java.io.BufferedInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.Writer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; +import java.io.*; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -354,13 +340,13 @@ private void doRequest(HttpServletRequest request, HttpServletResponse response, } @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + protected void doGet(HttpServletRequest req, HttpServletResponse resp) { init(); doRequestAsync(req, resp, getHandler); } @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + protected void doPost(HttpServletRequest req, HttpServletResponse resp) { init(); doRequestAsync(req, resp, postHandler); } @@ -373,7 +359,9 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL HttpServletRequest req, HttpServletResponse resp) throws IOException { ExecutionResult result = queryInvoker.query(invocationInput); - if (!(result.getData() instanceof Publisher)) { + boolean isDeferred = Objects.nonNull(result.getExtensions()) && result.getExtensions().containsKey(GraphQL.DEFERRED_RESULTS); + + if (!(result.getData() instanceof Publisher || isDeferred)) { resp.setContentType(APPLICATION_JSON_UTF8); resp.setStatus(STATUS_OK); resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result)); @@ -390,7 +378,16 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL AtomicReference subscriptionRef = new AtomicReference<>(); asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef)); ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper); - ((Publisher) result.getData()).subscribe(subscriber); + List> publishers = new ArrayList<>(); + if (result.getData() instanceof Publisher) { + publishers.add(result.getData()); + } else { + publishers.add(new StaticDataPublisher<>(result)); + final Publisher deferredResultsPublisher = (Publisher) result.getExtensions().get(GraphQL.DEFERRED_RESULTS); + publishers.add(deferredResultsPublisher); + } + publishers.forEach(it -> it.subscribe(subscriber)); + if (isInAsyncThread) { // We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed. try { @@ -537,7 +534,6 @@ public void onStartAsync(AsyncEvent event) { } } - private static class ExecutionResultSubscriber implements Subscriber { private final AtomicReference subscriptionRef; @@ -584,4 +580,13 @@ public void await() throws InterruptedException { completedLatch.await(); } } + + private static class StaticDataPublisher extends SingleSubscriberPublisher implements Publisher { + StaticDataPublisher(T data) { + super(); + super.offer(data); + super.noMoreData(); + } + } + } diff --git a/src/main/java/graphql/servlet/core/GraphQLObjectMapper.java b/src/main/java/graphql/servlet/core/GraphQLObjectMapper.java index 9ee6b6c8..1c25ea29 100644 --- a/src/main/java/graphql/servlet/core/GraphQLObjectMapper.java +++ b/src/main/java/graphql/servlet/core/GraphQLObjectMapper.java @@ -5,9 +5,8 @@ import com.fasterxml.jackson.databind.MappingIterator; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; -import graphql.ExecutionResult; -import graphql.ExecutionResultImpl; -import graphql.GraphQLError; +import graphql.*; +import graphql.execution.ExecutionPath; import graphql.servlet.config.ConfiguringObjectMapperProvider; import graphql.servlet.config.ObjectMapperConfigurer; import graphql.servlet.config.ObjectMapperProvider; @@ -117,12 +116,19 @@ public ExecutionResult sanitizeErrors(ExecutionResult executionResult) { } else { errors = null; } - return new ExecutionResultImpl(data, errors, extensions); } public Map createResultFromExecutionResult(ExecutionResult executionResult) { - return convertSanitizedExecutionResult(sanitizeErrors(executionResult)); + ExecutionResult sanitizedExecutionResult = sanitizeErrors(executionResult); + if (executionResult instanceof DeferredExecutionResult) { + sanitizedExecutionResult = DeferredExecutionResultImpl + .newDeferredExecutionResult() + .from(executionResult) + .path(ExecutionPath.fromList(((DeferredExecutionResult) executionResult).getPath())) + .build(); + } + return convertSanitizedExecutionResult(sanitizedExecutionResult); } public Map convertSanitizedExecutionResult(ExecutionResult executionResult) { @@ -144,6 +150,10 @@ public Map convertSanitizedExecutionResult(ExecutionResult execu result.put("data", executionResult.getData()); } + if (executionResult instanceof DeferredExecutionResult) { + result.put("path", ((DeferredExecutionResult) executionResult).getPath()); + } + return result; } diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index d00307e2..91c5ec19 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -283,6 +283,28 @@ class AbstractGraphQLHttpServletSpec extends Specification { getBatchedResponseContent()[1].data.echo == "test" } + + def "deferred query over HTTP GET"() { + setup: + request.addParameter('query', 'query { echo(arg:"test") @defer }') + + when: + servlet.doGet(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS + getSubscriptionResponseContent()[0].data.echo == null + + when: + subscriptionLatch.await(1, TimeUnit.SECONDS) + + then: + def content = getSubscriptionResponseContent() + content[1].data == "test" + content[1].path == ["echo"] + } + def "Batch Execution Handler allows limiting batches and sending error messages."() { setup: servlet = TestUtils.createBatchCustomizedServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env -> @@ -1030,6 +1052,61 @@ class AbstractGraphQLHttpServletSpec extends Specification { getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest" } + def "defer query over HTTP POST"() { + setup: + request.setContent('{"query": "subscription Subscription($arg: String!) { echo(arg: $arg) }", "operationName": "Subscription", "variables": {"arg": "test"}}'.bytes) + request.setAsyncSupported(true) + + when: + servlet.doPost(request, response) + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS + + when: + subscriptionLatch.await(1, TimeUnit.SECONDS) + then: + getSubscriptionResponseContent()[0].data.echo == "First\n\ntest" + getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest" + } + + def "deferred query that takes longer than initial results, should still be sent second"() { + setup: + servlet = TestUtils.createDefaultServlet({ env -> + if (env.getField().name == "a") { + Thread.sleep(1000) + } + env.arguments.arg + }) + request.setContent(mapper.writeValueAsBytes([ + query: ''' + { + object { + a(arg: "Hello") + b(arg: "World") @defer + } + } + ''' + ])) + request.setAsyncSupported(true) + + when: + servlet.doPost(request, response) + + then: + response.getStatus() == STATUS_OK + response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS + getSubscriptionResponseContent()[0].data.object.a == "Hello" // a has a Thread.sleep + + when: + subscriptionLatch.await(1, TimeUnit.SECONDS) + + then: + def content = getSubscriptionResponseContent() + content[1].data == "World" + content[1].path == ["object", "b"] + } + def "errors before graphql schema execution return internal server error"() { setup: servlet = SimpleGraphQLHttpServlet.newBuilder(GraphQLInvocationInputFactory.newBuilder { diff --git a/src/test/groovy/graphql/servlet/TestUtils.groovy b/src/test/groovy/graphql/servlet/TestUtils.groovy index 51e01edc..ec79d188 100644 --- a/src/test/groovy/graphql/servlet/TestUtils.groovy +++ b/src/test/groovy/graphql/servlet/TestUtils.groovy @@ -1,6 +1,7 @@ package graphql.servlet import com.google.common.io.ByteStreams +import graphql.Directives import graphql.Scalars import graphql.execution.reactive.SingleSubscriberPublisher import graphql.schema.* @@ -15,6 +16,7 @@ import graphql.servlet.core.ApolloScalars import graphql.servlet.input.BatchInputPreProcessor import graphql.servlet.context.ContextSetting +import java.util.concurrent.CompletableFuture import java.util.concurrent.atomic.AtomicReference class TestUtils { @@ -95,7 +97,7 @@ class TestUtils { static def createGraphQlSchema(DataFetcher queryDataFetcher = { env -> env.arguments.arg }, DataFetcher mutationDataFetcher = { env -> env.arguments.arg }, DataFetcher subscriptionDataFetcher = { env -> - AtomicReference> publisherRef = new AtomicReference<>(); + AtomicReference> publisherRef = new AtomicReference<>() publisherRef.set(new SingleSubscriberPublisher<>({ subscription -> publisherRef.get().offer(env.arguments.arg) publisherRef.get().noMoreData() @@ -113,6 +115,32 @@ class TestUtils { } field.dataFetcher(queryDataFetcher) } + .field { GraphQLFieldDefinition.Builder field -> + field.name("object") + field.type( + GraphQLObjectType.newObject() + .name("NestedObject") + .field { nested -> + nested.name("a") + nested.type(Scalars.GraphQLString) + nested.argument { argument -> + argument.name("arg") + argument.type(Scalars.GraphQLString) + } + nested.dataFetcher(queryDataFetcher) + } + .field { nested -> + nested.name("b") + nested.type(Scalars.GraphQLString) + nested.argument { argument -> + argument.name("arg") + argument.type(Scalars.GraphQLString) + } + nested.dataFetcher(queryDataFetcher) + } + ) + field.dataFetcher(new StaticDataFetcher([:])) + } .field { GraphQLFieldDefinition.Builder field -> field.name("returnsNullIncorrectly") field.type(new GraphQLNonNull(Scalars.GraphQLString)) @@ -174,6 +202,7 @@ class TestUtils { .mutation(mutation) .subscription(subscription) .additionalType(ApolloScalars.Upload) + .additionalDirective(Directives.DeferDirective) .build() }