39
39
class TestWebSocketHandler (WebSocketHandler ):
40
40
"""Base class for testing handlers that exposes the on_close event.
41
41
42
- This allows for deterministic cleanup of the associated socket.
42
+ This allows for tests to see the close code and reason on the
43
+ server side.
44
+
43
45
"""
44
46
45
- def initialize (self , close_future , compression_options = None ):
47
+ def initialize (self , close_future = None , compression_options = None ):
46
48
self .close_future = close_future
47
49
self .compression_options = compression_options
48
50
49
51
def get_compression_options (self ):
50
52
return self .compression_options
51
53
52
54
def on_close (self ):
53
- self .close_future .set_result ((self .close_code , self .close_reason ))
55
+ if self .close_future is not None :
56
+ self .close_future .set_result ((self .close_code , self .close_reason ))
54
57
55
58
56
59
class EchoHandler (TestWebSocketHandler ):
@@ -125,10 +128,8 @@ def open(self, arg):
125
128
126
129
127
130
class CoroutineOnMessageHandler (TestWebSocketHandler ):
128
- def initialize (self , close_future , compression_options = None ):
129
- super (CoroutineOnMessageHandler , self ).initialize (
130
- close_future , compression_options
131
- )
131
+ def initialize (self , ** kwargs ):
132
+ super (CoroutineOnMessageHandler , self ).initialize (** kwargs )
132
133
self .sleeping = 0
133
134
134
135
@gen .coroutine
@@ -191,16 +192,6 @@ def ws_connect(self, path, **kwargs):
191
192
)
192
193
raise gen .Return (ws )
193
194
194
- @gen .coroutine
195
- def close (self , ws ):
196
- """Close a websocket connection and wait for the server side.
197
-
198
- If we don't wait here, there are sometimes leak warnings in the
199
- tests.
200
- """
201
- ws .close ()
202
- yield self .close_future
203
-
204
195
205
196
class WebSocketTest (WebSocketBaseTestCase ):
206
197
def get_app (self ):
@@ -296,7 +287,6 @@ def test_websocket_gen(self):
296
287
yield ws .write_message ("hello" )
297
288
response = yield ws .read_message ()
298
289
self .assertEqual (response , "hello" )
299
- yield self .close (ws )
300
290
301
291
def test_websocket_callbacks (self ):
302
292
websocket_connect (
@@ -317,23 +307,20 @@ def test_binary_message(self):
317
307
ws .write_message (b"hello \xe9 " , binary = True )
318
308
response = yield ws .read_message ()
319
309
self .assertEqual (response , b"hello \xe9 " )
320
- yield self .close (ws )
321
310
322
311
@gen_test
323
312
def test_unicode_message (self ):
324
313
ws = yield self .ws_connect ("/echo" )
325
314
ws .write_message (u"hello \u00e9 " )
326
315
response = yield ws .read_message ()
327
316
self .assertEqual (response , u"hello \u00e9 " )
328
- yield self .close (ws )
329
317
330
318
@gen_test
331
319
def test_render_message (self ):
332
320
ws = yield self .ws_connect ("/render" )
333
321
ws .write_message ("hello" )
334
322
response = yield ws .read_message ()
335
323
self .assertEqual (response , "<b>hello</b>" )
336
- yield self .close (ws )
337
324
338
325
@gen_test
339
326
def test_error_in_on_message (self ):
@@ -342,7 +329,6 @@ def test_error_in_on_message(self):
342
329
with ExpectLog (app_log , "Uncaught exception" ):
343
330
response = yield ws .read_message ()
344
331
self .assertIs (response , None )
345
- yield self .close (ws )
346
332
347
333
@gen_test
348
334
def test_websocket_http_fail (self ):
@@ -372,7 +358,6 @@ def test_websocket_close_buffered_data(self):
372
358
ws .write_message ("world" )
373
359
# Close the underlying stream.
374
360
ws .stream .close ()
375
- yield self .close_future
376
361
377
362
@gen_test
378
363
def test_websocket_headers (self ):
@@ -385,7 +370,6 @@ def test_websocket_headers(self):
385
370
)
386
371
response = yield ws .read_message ()
387
372
self .assertEqual (response , "hello" )
388
- yield self .close (ws )
389
373
390
374
@gen_test
391
375
def test_websocket_header_echo (self ):
@@ -402,7 +386,6 @@ def test_websocket_header_echo(self):
402
386
self .assertEqual (
403
387
ws .headers .get ("X-Extra-Response-Header" ), "Extra-Response-Value"
404
388
)
405
- yield self .close (ws )
406
389
407
390
@gen_test
408
391
def test_server_close_reason (self ):
@@ -472,7 +455,6 @@ def test_check_origin_valid_no_path(self):
472
455
ws .write_message ("hello" )
473
456
response = yield ws .read_message ()
474
457
self .assertEqual (response , "hello" )
475
- yield self .close (ws )
476
458
477
459
@gen_test
478
460
def test_check_origin_valid_with_path (self ):
@@ -485,7 +467,6 @@ def test_check_origin_valid_with_path(self):
485
467
ws .write_message ("hello" )
486
468
response = yield ws .read_message ()
487
469
self .assertEqual (response , "hello" )
488
- yield self .close (ws )
489
470
490
471
@gen_test
491
472
def test_check_origin_invalid_partial_url (self ):
@@ -534,15 +515,13 @@ def test_subprotocols(self):
534
515
self .assertEqual (ws .selected_subprotocol , "goodproto" )
535
516
res = yield ws .read_message ()
536
517
self .assertEqual (res , "subprotocol=goodproto" )
537
- yield self .close (ws )
538
518
539
519
@gen_test
540
520
def test_subprotocols_not_offered (self ):
541
521
ws = yield self .ws_connect ("/subprotocol" )
542
522
self .assertIs (ws .selected_subprotocol , None )
543
523
res = yield ws .read_message ()
544
524
self .assertEqual (res , "subprotocol=None" )
545
- yield self .close (ws )
546
525
547
526
@gen_test
548
527
def test_open_coroutine (self ):
@@ -552,12 +531,11 @@ def test_open_coroutine(self):
552
531
self .message_sent .set ()
553
532
res = yield ws .read_message ()
554
533
self .assertEqual (res , "ok" )
555
- yield self .close (ws )
556
534
557
535
558
536
class NativeCoroutineOnMessageHandler (TestWebSocketHandler ):
559
- def initialize (self , close_future , compression_options = None ):
560
- super ().initialize (close_future , compression_options )
537
+ def initialize (self , ** kwargs ):
538
+ super ().initialize (** kwargs )
561
539
self .sleeping = 0
562
540
563
541
async def on_message (self , message ):
@@ -571,16 +549,7 @@ async def on_message(self, message):
571
549
572
550
class WebSocketNativeCoroutineTest (WebSocketBaseTestCase ):
573
551
def get_app (self ):
574
- self .close_future = Future () # type: Future[None]
575
- return Application (
576
- [
577
- (
578
- "/native" ,
579
- NativeCoroutineOnMessageHandler ,
580
- dict (close_future = self .close_future ),
581
- )
582
- ]
583
- )
552
+ return Application ([("/native" , NativeCoroutineOnMessageHandler )])
584
553
585
554
@gen_test
586
555
def test_native_coroutine (self ):
@@ -598,8 +567,6 @@ class CompressionTestMixin(object):
598
567
MESSAGE = "Hello world. Testing 123 123"
599
568
600
569
def get_app (self ):
601
- self .close_future = Future () # type: Future[None]
602
-
603
570
class LimitedHandler (TestWebSocketHandler ):
604
571
@property
605
572
def max_message_size (self ):
@@ -613,18 +580,12 @@ def on_message(self, message):
613
580
(
614
581
"/echo" ,
615
582
EchoHandler ,
616
- dict (
617
- close_future = self .close_future ,
618
- compression_options = self .get_server_compression_options (),
619
- ),
583
+ dict (compression_options = self .get_server_compression_options ()),
620
584
),
621
585
(
622
586
"/limited" ,
623
587
LimitedHandler ,
624
- dict (
625
- close_future = self .close_future ,
626
- compression_options = self .get_server_compression_options (),
627
- ),
588
+ dict (compression_options = self .get_server_compression_options ()),
628
589
),
629
590
]
630
591
)
@@ -649,7 +610,6 @@ def test_message_sizes(self):
649
610
self .assertEqual (ws .protocol ._message_bytes_out , len (self .MESSAGE ) * 3 )
650
611
self .assertEqual (ws .protocol ._message_bytes_in , len (self .MESSAGE ) * 3 )
651
612
self .verify_wire_bytes (ws .protocol ._wire_bytes_in , ws .protocol ._wire_bytes_out )
652
- yield self .close (ws )
653
613
654
614
@gen_test
655
615
def test_size_limit (self ):
@@ -665,7 +625,6 @@ def test_size_limit(self):
665
625
ws .write_message ("a" * 2048 )
666
626
response = yield ws .read_message ()
667
627
self .assertIsNone (response )
668
- yield self .close (ws )
669
628
670
629
671
630
class UncompressedTestMixin (CompressionTestMixin ):
@@ -743,19 +702,14 @@ class PingHandler(TestWebSocketHandler):
743
702
def on_pong (self , data ):
744
703
self .write_message ("got pong" )
745
704
746
- self .close_future = Future () # type: Future[None]
747
- return Application (
748
- [("/" , PingHandler , dict (close_future = self .close_future ))],
749
- websocket_ping_interval = 0.01 ,
750
- )
705
+ return Application ([("/" , PingHandler )], websocket_ping_interval = 0.01 )
751
706
752
707
@gen_test
753
708
def test_server_ping (self ):
754
709
ws = yield self .ws_connect ("/" )
755
710
for i in range (3 ):
756
711
response = yield ws .read_message ()
757
712
self .assertEqual (response , "got pong" )
758
- yield self .close (ws )
759
713
# TODO: test that the connection gets closed if ping responses stop.
760
714
761
715
@@ -765,16 +719,14 @@ class PingHandler(TestWebSocketHandler):
765
719
def on_ping (self , data ):
766
720
self .write_message ("got ping" )
767
721
768
- self .close_future = Future () # type: Future[None]
769
- return Application ([("/" , PingHandler , dict (close_future = self .close_future ))])
722
+ return Application ([("/" , PingHandler )])
770
723
771
724
@gen_test
772
725
def test_client_ping (self ):
773
726
ws = yield self .ws_connect ("/" , ping_interval = 0.01 )
774
727
for i in range (3 ):
775
728
response = yield ws .read_message ()
776
729
self .assertEqual (response , "got ping" )
777
- yield self .close (ws )
778
730
# TODO: test that the connection gets closed if ping responses stop.
779
731
780
732
@@ -784,8 +736,7 @@ class PingHandler(TestWebSocketHandler):
784
736
def on_ping (self , data ):
785
737
self .write_message (data , binary = isinstance (data , bytes ))
786
738
787
- self .close_future = Future () # type: Future[None]
788
- return Application ([("/" , PingHandler , dict (close_future = self .close_future ))])
739
+ return Application ([("/" , PingHandler )])
789
740
790
741
@gen_test
791
742
def test_manual_ping (self ):
@@ -801,16 +752,11 @@ def test_manual_ping(self):
801
752
ws .ping (b"binary hello" )
802
753
resp = yield ws .read_message ()
803
754
self .assertEqual (resp , b"binary hello" )
804
- yield self .close (ws )
805
755
806
756
807
757
class MaxMessageSizeTest (WebSocketBaseTestCase ):
808
758
def get_app (self ):
809
- self .close_future = Future () # type: Future[None]
810
- return Application (
811
- [("/" , EchoHandler , dict (close_future = self .close_future ))],
812
- websocket_max_message_size = 1024 ,
813
- )
759
+ return Application ([("/" , EchoHandler )], websocket_max_message_size = 1024 )
814
760
815
761
@gen_test
816
762
def test_large_message (self ):
0 commit comments