Skip to content

Commit 00c379a

Browse files
committed
Add support for subscriptions (#4)
1 parent 730df19 commit 00c379a

File tree

6 files changed

+99
-3
lines changed

6 files changed

+99
-3
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ gql_view(request) # <-- the instance is callable and expects a `aiohttp.web.Req
5252
- `encoder`: the encoder to use for responses (sensibly defaults to `graphql_server.json_encode`)
5353
- `error_formatter`: the error formatter to use for responses (sensibly defaults to `graphql_server.default_format_error`)
5454
- `enable_async`: whether `async` mode will be enabled.
55+
- `subscriptions`: The [GraphiQL] socket endpoint for using subscriptions in [graphql-ws].
5556

5657

5758
## Testing
@@ -85,3 +86,4 @@ This project is licensed under the MIT License.
8586
[Apollo-Client]: http://dev.apollodata.com/core/network.html#query-batching
8687
[Devin Fee]: https://github.com/dfee
8788
[aiohttp-graphql]: https://github.com/graphql-python/aiohttp-graphql
89+
[graphql-ws]: https://github.com/graphql-python/graphql-ws

aiohttp_graphql/graphqlview.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
encoder=None,
3636
error_formatter=None,
3737
enable_async=True,
38+
subscriptions=None,
3839
**execution_options,
3940
):
4041
# pylint: disable=too-many-arguments
@@ -55,6 +56,7 @@ def __init__(
5556
self.encoder = encoder or json_encode
5657
self.error_formatter = error_formatter or default_format_error
5758
self.enable_async = enable_async and isinstance(self.executor, AsyncioExecutor)
59+
self.subscriptions = subscriptions
5860
self.execution_options = execution_options
5961
assert isinstance(
6062
self.schema, GraphQLSchema
@@ -97,6 +99,7 @@ def render_graphiql(self, params, result):
9799
result=result,
98100
graphiql_version=self.graphiql_version,
99101
graphiql_template=self.graphiql_template,
102+
subscriptions=self.subscriptions,
100103
)
101104

102105
def is_graphiql(self, request):

aiohttp_graphql/render_graphiql.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
<script src="//cdn.jsdelivr.net/npm/[email protected]/umd/react.production.min.js"></script>
3030
<script src="//cdn.jsdelivr.net/npm/[email protected]/umd/react-dom.production.min.js"></script>
3131
<script src="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"></script>
32+
<script src="//cdn.jsdelivr.net/npm/[email protected]/browser/client.js"></script>
33+
<script src="//cdn.jsdelivr.net//npm/[email protected]/browser/client.js"></script>
3234
</head>
3335
<body>
3436
<script>
@@ -63,6 +65,20 @@
6365
otherParams[k] = parameters[k];
6466
}
6567
}
68+
69+
var subscriptionsFetcher;
70+
if ('{{subscriptions}}') {
71+
const subscriptionsClient = new SubscriptionsTransportWs.SubscriptionClient(
72+
'{{ subscriptions }}',
73+
{reconnect: true}
74+
);
75+
76+
subscriptionsFetcher = GraphiQLSubscriptionsFetcher.graphQLFetcher(
77+
subscriptionsClient,
78+
graphQLFetcher
79+
);
80+
}
81+
6682
var fetchURL = locationQuery(otherParams);
6783
6884
// Defines a GraphQL fetcher using the fetch API.
@@ -110,7 +126,7 @@
110126
// Render <GraphiQL /> into the body.
111127
ReactDOM.render(
112128
React.createElement(GraphiQL, {
113-
fetcher: graphQLFetcher,
129+
fetcher: subscriptionsFetcher || graphQLFetcher,
114130
onEditQuery: onEditQuery,
115131
onEditVariables: onEditVariables,
116132
onEditOperationName: onEditOperationName,
@@ -149,7 +165,7 @@ def process_var(template, name, value, jsonify=False):
149165

150166

151167
def simple_renderer(template, **values):
152-
replace = ["graphiql_version"]
168+
replace = ["graphiql_version", "subscriptions"]
153169
replace_jsonify = ["query", "result", "variables", "operation_name"]
154170

155171
for rep in replace:
@@ -167,6 +183,7 @@ async def render_graphiql(
167183
graphiql_template=None,
168184
params=None,
169185
result=None,
186+
subscriptions=None,
170187
):
171188
graphiql_version = graphiql_version or GRAPHIQL_VERSION
172189
template = graphiql_template or TEMPLATE
@@ -176,6 +193,7 @@ async def render_graphiql(
176193
"variables": params and params.variables,
177194
"operation_name": params and params.operation_name,
178195
"result": result,
196+
"subscriptions": subscriptions or "",
179197
}
180198

181199
if jinja_env:

tests/schema.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,16 @@ def resolve_raises(*args):
4949
},
5050
)
5151

52+
SubscriptionsRootType = GraphQLObjectType(
53+
name="SubscriptionsRoot",
54+
fields={
55+
"subscriptionsTest": GraphQLField(
56+
type=QueryRootType, resolver=lambda *args: QueryRootType
57+
)
58+
},
59+
)
5260

53-
Schema = GraphQLSchema(QueryRootType, MutationRootType)
61+
Schema = GraphQLSchema(QueryRootType, MutationRootType, SubscriptionsRootType)
5462

5563

5664
# Schema with async methods

tests/test_graphiqlview.py

+12
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,18 @@ async def test_graphiql_get_mutation(client, url_builder):
7878
assert "response: null" in await response.text()
7979

8080

81+
@pytest.mark.asyncio
82+
async def test_graphiql_get_subscriptions(client, url_builder):
83+
response = await client.get(
84+
url_builder(
85+
query=("subscription TestSubscriptions { subscriptionsTest { test } }")
86+
),
87+
headers={"Accept": "text/html"},
88+
)
89+
assert response.status == 200
90+
assert "response: null" in await response.text()
91+
92+
8193
class TestAsyncSchema:
8294
@pytest.fixture
8395
def executor(self, event_loop):

tests/test_graphqlview.py

+53
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ async def test_errors_when_missing_operation_name(client, url_builder):
102102
query="""
103103
query TestQuery { test }
104104
mutation TestMutation { writeTest { test } }
105+
subscription TestSubscriptions { subscriptionsTest { test } }
105106
"""
106107
)
107108
)
@@ -156,6 +157,31 @@ async def test_errors_when_selecting_a_mutation_within_a_get(client, url_builder
156157
}
157158

158159

160+
@pytest.mark.asyncio
161+
async def test_errors_when_selecting_a_subscription_within_a_get(
162+
client, url_builder,
163+
):
164+
response = await client.get(
165+
url_builder(
166+
query="""
167+
subscription TestSubscriptions { subscriptionsTest { test } }
168+
""",
169+
operationName="TestSubscriptions",
170+
)
171+
)
172+
173+
assert response.status == 405
174+
assert await response.json() == {
175+
"errors": [
176+
{
177+
"message": (
178+
"Can only perform a subscription operation from a POST " "request."
179+
)
180+
},
181+
],
182+
}
183+
184+
159185
@pytest.mark.asyncio
160186
async def test_allows_mutation_to_exist_within_a_get(client, url_builder):
161187
response = await client.get(
@@ -196,6 +222,33 @@ async def test_allows_sending_a_mutation_via_post(client, base_url):
196222
assert await response.json() == {"data": {"writeTest": {"test": "Hello World"}}}
197223

198224

225+
@pytest.mark.asyncio
226+
async def test_errors_when_sending_a_subscription_without_allow(client, base_url):
227+
response = await client.post(
228+
base_url,
229+
data=json.dumps(
230+
dict(
231+
query="""
232+
subscription TestSubscriptions { subscriptionsTest { test } }
233+
""",
234+
)
235+
),
236+
headers={"content-type": "application/json"},
237+
)
238+
239+
assert response.status == 200
240+
assert await response.json() == {
241+
"data": None,
242+
"errors": [
243+
{
244+
"message": "Subscriptions are not allowed. You will need to "
245+
"either use the subscribe function or pass "
246+
"allow_subscriptions=True"
247+
},
248+
],
249+
}
250+
251+
199252
@pytest.mark.asyncio
200253
async def test_allows_post_with_url_encoding(client, base_url):
201254
data = FormData()

0 commit comments

Comments
 (0)