12
12
13
13
from .rest import ApiException
14
14
15
+ import select
15
16
import certifi
17
+ import time
16
18
import collections
17
- import websocket
19
+ from websocket import WebSocket , ABNF , enableTrace
18
20
import six
19
21
import ssl
20
22
from six .moves .urllib .parse import urlencode
21
23
from six .moves .urllib .parse import quote_plus
22
-
24
+ import socket
23
25
24
26
class WSClient :
25
27
def __init__ (self , configuration , url , headers ):
26
- self .messages = []
27
- self .errors = []
28
- websocket .enableTrace (False )
29
- header = None
28
+ enableTrace (False )
29
+ header = []
30
+ self ._connected = False
31
+ self ._stdout = ""
32
+ self ._stderr = ""
33
+ self ._all = ""
30
34
31
35
# We just need to pass the Authorization, ignore all the other
32
36
# http headers we get from the generated code
33
37
if 'Authorization' in headers :
34
- header = "Authorization: %s" % headers ['Authorization' ]
35
-
36
- self .ws = websocket .WebSocketApp (url ,
37
- on_message = self .on_message ,
38
- on_error = self .on_error ,
39
- on_close = self .on_close ,
40
- header = [header ] if header else None )
41
- self .ws .on_open = self .on_open
38
+ header .append ("Authorization: %s" % headers ['Authorization' ])
42
39
43
40
if url .startswith ('wss://' ) and configuration .verify_ssl :
44
41
ssl_opts = {
@@ -52,30 +49,87 @@ def __init__(self, configuration, url, headers):
52
49
else :
53
50
ssl_opts = {'cert_reqs' : ssl .CERT_NONE }
54
51
55
- self .ws .run_forever (sslopt = ssl_opts )
56
-
57
- def on_message (self , ws , message ):
58
- if message [0 ] == '\x01 ' :
59
- message = message [1 :]
60
- if message :
61
- if six .PY3 and isinstance (message , six .binary_type ):
62
- message = message .decode ('utf-8' )
63
- self .messages .append (message )
64
-
65
- def on_error (self , ws , error ):
66
- self .errors .append (error )
67
-
68
- def on_close (self , ws ):
69
- pass
70
-
71
- def on_open (self , ws ):
72
- pass
52
+ self .sock = WebSocket (sslopt = ssl_opts , skip_utf8_validation = False )
53
+ self .sock .connect (url , header = header )
54
+ self ._connected = True
55
+
56
+ def peek_stdout (self ):
57
+ self .update ()
58
+ return self ._stdout
59
+
60
+ def read_stdout (self ):
61
+ if not self ._stdout :
62
+ self .update (timeout = None )
63
+ ret = self ._stdout
64
+ self ._stdout = ""
65
+ return ret
66
+
67
+ def peek_stderr (self ):
68
+ self .update ()
69
+ return self ._stderr
70
+
71
+ def read_stderr (self ):
72
+ if not self ._stderr :
73
+ self .update (timeout = None )
74
+ ret = self ._stderr
75
+ self ._stderr = ""
76
+ return ret
77
+
78
+ def read_all (self ):
79
+ out = self ._all
80
+ self ._all = ""
81
+ self ._stdout = ""
82
+ self ._stderr = ""
83
+ return out
84
+
85
+ def is_open (self ):
86
+ return self ._connected
87
+
88
+ # TODO: This method does not seem to work.
89
+ def write_stdin (self , data ):
90
+ self .sock .send (data )
91
+
92
+ def update (self , timeout = 0 ):
93
+ if not self .is_open ():
94
+ return
95
+ if not self .sock .connected :
96
+ self ._connected = False
97
+ return
98
+ r , _ , _ = select .select (
99
+ (self .sock .sock , ), (), (), timeout )
100
+ if r :
101
+ op_code , frame = self .sock .recv_data_frame (True )
102
+ if op_code == ABNF .OPCODE_CLOSE :
103
+ self ._connected = False
104
+ return
105
+ elif op_code == ABNF .OPCODE_BINARY or op_code == ABNF .OPCODE_TEXT :
106
+ data = frame .data
107
+ if six .PY3 and op_code == ABNF .OPCODE_TEXT :
108
+ data = data .decode ("utf-8" )
109
+ if data [0 ] == '\x01 ' :
110
+ data = data [1 :]
111
+ if data :
112
+ self ._all += data
113
+ if data [0 ] == '\x02 ' :
114
+ self ._stderr += data [1 :]
115
+ else :
116
+ self ._stdout += data
117
+
118
+ def run_forever (self , timeout = None ):
119
+ if timeout :
120
+ start = time .time ()
121
+ while self .is_open () and time .time () - start < timeout :
122
+ self .update (timeout = (timeout - time .time () + start ))
123
+ else :
124
+ while self .is_open ():
125
+ self .update (timeout = None )
73
126
74
127
75
128
WSResponse = collections .namedtuple ('WSResponse' , ['data' ])
76
129
77
130
78
- def GET (configuration , url , query_params , _request_timeout , headers ):
131
+ def GET (configuration , url , query_params , _request_timeout , _preload_content ,
132
+ headers ):
79
133
# switch protocols from http to websocket
80
134
url = url .replace ('http://' , 'ws://' )
81
135
url = url .replace ('https://' , 'wss://' )
@@ -105,10 +159,12 @@ def GET(configuration, url, query_params, _request_timeout, headers):
105
159
else :
106
160
url += '&command=' + quote_plus (commands )
107
161
108
- client = WSClient (configuration , url , headers )
109
- if client .errors :
110
- raise ApiException (
111
- status = 0 ,
112
- reason = '\n ' .join ([str (error ) for error in client .errors ])
113
- )
162
+ try :
163
+ client = WSClient (configuration , url , headers )
164
+ if not _preload_content :
165
+ return client
166
+ client .run_forever (timeout = _request_timeout )
167
+ return WSResponse ('%s' % '' .join (client .read_all ()))
168
+ except (Exception , KeyboardInterrupt , SystemExit ) as e :
169
+ raise ApiException (status = 0 , reason = str (e ))
114
170
return WSResponse ('%s' % '' .join (client .messages ))
0 commit comments