Skip to content

Commit ae9a2da

Browse files
committed
websocket_test: Remove most manual closes
At one time this was necessary to prevent spurious warnings at shutdown, but not any more (and I intend to address warnings like this with a more general solution).
1 parent c350dc9 commit ae9a2da

File tree

1 file changed

+17
-71
lines changed

1 file changed

+17
-71
lines changed

tornado/test/websocket_test.py

+17-71
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,21 @@
3939
class TestWebSocketHandler(WebSocketHandler):
4040
"""Base class for testing handlers that exposes the on_close event.
4141
42-
This allows for deterministic cleanup of the associated socket.
42+
This allows for tests to see the close code and reason on the
43+
server side.
44+
4345
"""
4446

45-
def initialize(self, close_future, compression_options=None):
47+
def initialize(self, close_future=None, compression_options=None):
4648
self.close_future = close_future
4749
self.compression_options = compression_options
4850

4951
def get_compression_options(self):
5052
return self.compression_options
5153

5254
def on_close(self):
53-
self.close_future.set_result((self.close_code, self.close_reason))
55+
if self.close_future is not None:
56+
self.close_future.set_result((self.close_code, self.close_reason))
5457

5558

5659
class EchoHandler(TestWebSocketHandler):
@@ -125,10 +128,8 @@ def open(self, arg):
125128

126129

127130
class CoroutineOnMessageHandler(TestWebSocketHandler):
128-
def initialize(self, close_future, compression_options=None):
129-
super(CoroutineOnMessageHandler, self).initialize(
130-
close_future, compression_options
131-
)
131+
def initialize(self, **kwargs):
132+
super(CoroutineOnMessageHandler, self).initialize(**kwargs)
132133
self.sleeping = 0
133134

134135
@gen.coroutine
@@ -191,16 +192,6 @@ def ws_connect(self, path, **kwargs):
191192
)
192193
raise gen.Return(ws)
193194

194-
@gen.coroutine
195-
def close(self, ws):
196-
"""Close a websocket connection and wait for the server side.
197-
198-
If we don't wait here, there are sometimes leak warnings in the
199-
tests.
200-
"""
201-
ws.close()
202-
yield self.close_future
203-
204195

205196
class WebSocketTest(WebSocketBaseTestCase):
206197
def get_app(self):
@@ -296,7 +287,6 @@ def test_websocket_gen(self):
296287
yield ws.write_message("hello")
297288
response = yield ws.read_message()
298289
self.assertEqual(response, "hello")
299-
yield self.close(ws)
300290

301291
def test_websocket_callbacks(self):
302292
websocket_connect(
@@ -317,23 +307,20 @@ def test_binary_message(self):
317307
ws.write_message(b"hello \xe9", binary=True)
318308
response = yield ws.read_message()
319309
self.assertEqual(response, b"hello \xe9")
320-
yield self.close(ws)
321310

322311
@gen_test
323312
def test_unicode_message(self):
324313
ws = yield self.ws_connect("/echo")
325314
ws.write_message(u"hello \u00e9")
326315
response = yield ws.read_message()
327316
self.assertEqual(response, u"hello \u00e9")
328-
yield self.close(ws)
329317

330318
@gen_test
331319
def test_render_message(self):
332320
ws = yield self.ws_connect("/render")
333321
ws.write_message("hello")
334322
response = yield ws.read_message()
335323
self.assertEqual(response, "<b>hello</b>")
336-
yield self.close(ws)
337324

338325
@gen_test
339326
def test_error_in_on_message(self):
@@ -342,7 +329,6 @@ def test_error_in_on_message(self):
342329
with ExpectLog(app_log, "Uncaught exception"):
343330
response = yield ws.read_message()
344331
self.assertIs(response, None)
345-
yield self.close(ws)
346332

347333
@gen_test
348334
def test_websocket_http_fail(self):
@@ -372,7 +358,6 @@ def test_websocket_close_buffered_data(self):
372358
ws.write_message("world")
373359
# Close the underlying stream.
374360
ws.stream.close()
375-
yield self.close_future
376361

377362
@gen_test
378363
def test_websocket_headers(self):
@@ -385,7 +370,6 @@ def test_websocket_headers(self):
385370
)
386371
response = yield ws.read_message()
387372
self.assertEqual(response, "hello")
388-
yield self.close(ws)
389373

390374
@gen_test
391375
def test_websocket_header_echo(self):
@@ -402,7 +386,6 @@ def test_websocket_header_echo(self):
402386
self.assertEqual(
403387
ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
404388
)
405-
yield self.close(ws)
406389

407390
@gen_test
408391
def test_server_close_reason(self):
@@ -472,7 +455,6 @@ def test_check_origin_valid_no_path(self):
472455
ws.write_message("hello")
473456
response = yield ws.read_message()
474457
self.assertEqual(response, "hello")
475-
yield self.close(ws)
476458

477459
@gen_test
478460
def test_check_origin_valid_with_path(self):
@@ -485,7 +467,6 @@ def test_check_origin_valid_with_path(self):
485467
ws.write_message("hello")
486468
response = yield ws.read_message()
487469
self.assertEqual(response, "hello")
488-
yield self.close(ws)
489470

490471
@gen_test
491472
def test_check_origin_invalid_partial_url(self):
@@ -534,15 +515,13 @@ def test_subprotocols(self):
534515
self.assertEqual(ws.selected_subprotocol, "goodproto")
535516
res = yield ws.read_message()
536517
self.assertEqual(res, "subprotocol=goodproto")
537-
yield self.close(ws)
538518

539519
@gen_test
540520
def test_subprotocols_not_offered(self):
541521
ws = yield self.ws_connect("/subprotocol")
542522
self.assertIs(ws.selected_subprotocol, None)
543523
res = yield ws.read_message()
544524
self.assertEqual(res, "subprotocol=None")
545-
yield self.close(ws)
546525

547526
@gen_test
548527
def test_open_coroutine(self):
@@ -552,12 +531,11 @@ def test_open_coroutine(self):
552531
self.message_sent.set()
553532
res = yield ws.read_message()
554533
self.assertEqual(res, "ok")
555-
yield self.close(ws)
556534

557535

558536
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
559-
def initialize(self, close_future, compression_options=None):
560-
super().initialize(close_future, compression_options)
537+
def initialize(self, **kwargs):
538+
super().initialize(**kwargs)
561539
self.sleeping = 0
562540

563541
async def on_message(self, message):
@@ -571,16 +549,7 @@ async def on_message(self, message):
571549

572550
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
573551
def get_app(self):
574-
self.close_future = Future() # type: Future[None]
575-
return Application(
576-
[
577-
(
578-
"/native",
579-
NativeCoroutineOnMessageHandler,
580-
dict(close_future=self.close_future),
581-
)
582-
]
583-
)
552+
return Application([("/native", NativeCoroutineOnMessageHandler)])
584553

585554
@gen_test
586555
def test_native_coroutine(self):
@@ -598,8 +567,6 @@ class CompressionTestMixin(object):
598567
MESSAGE = "Hello world. Testing 123 123"
599568

600569
def get_app(self):
601-
self.close_future = Future() # type: Future[None]
602-
603570
class LimitedHandler(TestWebSocketHandler):
604571
@property
605572
def max_message_size(self):
@@ -613,18 +580,12 @@ def on_message(self, message):
613580
(
614581
"/echo",
615582
EchoHandler,
616-
dict(
617-
close_future=self.close_future,
618-
compression_options=self.get_server_compression_options(),
619-
),
583+
dict(compression_options=self.get_server_compression_options()),
620584
),
621585
(
622586
"/limited",
623587
LimitedHandler,
624-
dict(
625-
close_future=self.close_future,
626-
compression_options=self.get_server_compression_options(),
627-
),
588+
dict(compression_options=self.get_server_compression_options()),
628589
),
629590
]
630591
)
@@ -649,7 +610,6 @@ def test_message_sizes(self):
649610
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
650611
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
651612
self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
652-
yield self.close(ws)
653613

654614
@gen_test
655615
def test_size_limit(self):
@@ -665,7 +625,6 @@ def test_size_limit(self):
665625
ws.write_message("a" * 2048)
666626
response = yield ws.read_message()
667627
self.assertIsNone(response)
668-
yield self.close(ws)
669628

670629

671630
class UncompressedTestMixin(CompressionTestMixin):
@@ -743,19 +702,14 @@ class PingHandler(TestWebSocketHandler):
743702
def on_pong(self, data):
744703
self.write_message("got pong")
745704

746-
self.close_future = Future() # type: Future[None]
747-
return Application(
748-
[("/", PingHandler, dict(close_future=self.close_future))],
749-
websocket_ping_interval=0.01,
750-
)
705+
return Application([("/", PingHandler)], websocket_ping_interval=0.01)
751706

752707
@gen_test
753708
def test_server_ping(self):
754709
ws = yield self.ws_connect("/")
755710
for i in range(3):
756711
response = yield ws.read_message()
757712
self.assertEqual(response, "got pong")
758-
yield self.close(ws)
759713
# TODO: test that the connection gets closed if ping responses stop.
760714

761715

@@ -765,16 +719,14 @@ class PingHandler(TestWebSocketHandler):
765719
def on_ping(self, data):
766720
self.write_message("got ping")
767721

768-
self.close_future = Future() # type: Future[None]
769-
return Application([("/", PingHandler, dict(close_future=self.close_future))])
722+
return Application([("/", PingHandler)])
770723

771724
@gen_test
772725
def test_client_ping(self):
773726
ws = yield self.ws_connect("/", ping_interval=0.01)
774727
for i in range(3):
775728
response = yield ws.read_message()
776729
self.assertEqual(response, "got ping")
777-
yield self.close(ws)
778730
# TODO: test that the connection gets closed if ping responses stop.
779731

780732

@@ -784,8 +736,7 @@ class PingHandler(TestWebSocketHandler):
784736
def on_ping(self, data):
785737
self.write_message(data, binary=isinstance(data, bytes))
786738

787-
self.close_future = Future() # type: Future[None]
788-
return Application([("/", PingHandler, dict(close_future=self.close_future))])
739+
return Application([("/", PingHandler)])
789740

790741
@gen_test
791742
def test_manual_ping(self):
@@ -801,16 +752,11 @@ def test_manual_ping(self):
801752
ws.ping(b"binary hello")
802753
resp = yield ws.read_message()
803754
self.assertEqual(resp, b"binary hello")
804-
yield self.close(ws)
805755

806756

807757
class MaxMessageSizeTest(WebSocketBaseTestCase):
808758
def get_app(self):
809-
self.close_future = Future() # type: Future[None]
810-
return Application(
811-
[("/", EchoHandler, dict(close_future=self.close_future))],
812-
websocket_max_message_size=1024,
813-
)
759+
return Application([("/", EchoHandler)], websocket_max_message_size=1024)
814760

815761
@gen_test
816762
def test_large_message(self):

0 commit comments

Comments
 (0)