12
12
13
13
from .rest import ApiException
14
14
15
+ import select
15
16
import certifi
17
+ import time
16
18
import collections
17
- import websocket
19
+ from websocket import WebSocket , ABNF , enableTrace
18
20
import six
19
21
import ssl
20
22
from six .moves .urllib .parse import urlencode
21
23
from six .moves .urllib .parse import quote_plus
22
24
25
+ STDIN_CHANNEL = 0
26
+ STDOUT_CHANNEL = 1
27
+ STDERR_CHANNEL = 2
28
+
23
29
24
30
class WSClient :
25
31
def __init__ (self , configuration , url , headers ):
26
- self .messages = []
27
- self .errors = []
28
- websocket .enableTrace (False )
29
- header = None
32
+ enableTrace (False )
33
+ header = []
34
+ self ._connected = False
35
+ self ._channels = {}
36
+ self ._all = ""
30
37
31
38
# We just need to pass the Authorization, ignore all the other
32
39
# http headers we get from the generated code
33
40
if 'Authorization' in headers :
34
- header = "Authorization: %s" % headers ['Authorization' ]
35
-
36
- self .ws = websocket .WebSocketApp (url ,
37
- on_message = self .on_message ,
38
- on_error = self .on_error ,
39
- on_close = self .on_close ,
40
- header = [header ] if header else None )
41
- self .ws .on_open = self .on_open
41
+ header .append ("Authorization: %s" % headers ['Authorization' ])
42
42
43
43
if url .startswith ('wss://' ) and configuration .verify_ssl :
44
44
ssl_opts = {
@@ -52,30 +52,118 @@ def __init__(self, configuration, url, headers):
52
52
else :
53
53
ssl_opts = {'cert_reqs' : ssl .CERT_NONE }
54
54
55
- self .ws .run_forever (sslopt = ssl_opts )
56
-
57
- def on_message (self , ws , message ):
58
- if message [0 ] == '\x01 ' :
59
- message = message [1 :]
60
- if message :
61
- if six .PY3 and isinstance (message , six .binary_type ):
62
- message = message .decode ('utf-8' )
63
- self .messages .append (message )
55
+ self .sock = WebSocket (sslopt = ssl_opts , skip_utf8_validation = False )
56
+ self .sock .connect (url , header = header )
57
+ self ._connected = True
64
58
65
- def on_error (self , ws , error ):
66
- self .errors .append (error )
59
+ def peek_channel (self , channel , timeout = 0 ):
60
+ self .update (timeout = timeout )
61
+ if channel in self ._channels :
62
+ return self ._channels [channel ]
63
+ return ""
67
64
68
- def on_close (self , ws ):
69
- pass
70
-
71
- def on_open (self , ws ):
72
- pass
65
+ def read_channel (self , channel , timeout = 0 ):
66
+ if channel not in self ._channels :
67
+ ret = self .peek_channel (channel , timeout )
68
+ else :
69
+ ret = self ._channels [channel ]
70
+ if channel in self ._channels :
71
+ del self ._channels [channel ]
72
+ return ret
73
+
74
+ def readline_channel (self , channel , timeout = None ):
75
+ if timeout is None :
76
+ timeout = float ("inf" )
77
+ start = time .time ()
78
+ while self .is_open () and time .time () - start < timeout :
79
+ if channel in self ._channels :
80
+ data = self ._channels [channel ]
81
+ if "\n " in data :
82
+ index = data .find ("\n " )
83
+ ret = data [:index ]
84
+ data = data [index + 1 :]
85
+ if data :
86
+ self ._channels [channel ] = data
87
+ else :
88
+ del self ._channels [channel ]
89
+ return ret
90
+ self .update (timeout = (timeout - time .time () + start ))
91
+
92
+ def write_channel (self , channel , data ):
93
+ self .sock .send (chr (channel ) + data )
94
+
95
+ def peek_stdout (self , timeout = 0 ):
96
+ return self .peek_channel (STDOUT_CHANNEL , timeout = timeout )
97
+
98
+ def read_stdout (self , timeout = None ):
99
+ return self .read_channel (STDOUT_CHANNEL , timeout = timeout )
100
+
101
+ def readline_stdout (self , timeout = None ):
102
+ return self .readline_channel (STDOUT_CHANNEL , timeout = timeout )
103
+
104
+ def peek_stderr (self , timeout = 0 ):
105
+ return self .peek_channel (STDERR_CHANNEL , timeout = timeout )
106
+
107
+ def read_stderr (self , timeout = None ):
108
+ return self .read_channel (STDERR_CHANNEL , timeout = timeout )
109
+
110
+ def readline_stderr (self , timeout = None ):
111
+ return self .readline_channel (STDERR_CHANNEL , timeout = timeout )
112
+
113
+ def read_all (self ):
114
+ out = self ._all
115
+ self ._all = ""
116
+ self ._channels = {}
117
+ return out
118
+
119
+ def is_open (self ):
120
+ return self ._connected
121
+
122
+ def write_stdin (self , data ):
123
+ self .write_channel (STDIN_CHANNEL , data )
124
+
125
+ def update (self , timeout = 0 ):
126
+ if not self .is_open ():
127
+ return
128
+ if not self .sock .connected :
129
+ self ._connected = False
130
+ return
131
+ r , _ , _ = select .select (
132
+ (self .sock .sock , ), (), (), timeout )
133
+ if r :
134
+ op_code , frame = self .sock .recv_data_frame (True )
135
+ if op_code == ABNF .OPCODE_CLOSE :
136
+ self ._connected = False
137
+ return
138
+ elif op_code == ABNF .OPCODE_BINARY or op_code == ABNF .OPCODE_TEXT :
139
+ data = frame .data
140
+ if six .PY3 :
141
+ data = data .decode ("utf-8" )
142
+ self ._all += data
143
+ if len (data ) > 1 :
144
+ channel = ord (data [0 ])
145
+ data = data [1 :]
146
+ if data :
147
+ if channel not in self ._channels :
148
+ self ._channels [channel ] = data
149
+ else :
150
+ self ._channels [channel ] += data
151
+
152
+ def run_forever (self , timeout = None ):
153
+ if timeout :
154
+ start = time .time ()
155
+ while self .is_open () and time .time () - start < timeout :
156
+ self .update (timeout = (timeout - time .time () + start ))
157
+ else :
158
+ while self .is_open ():
159
+ self .update (timeout = None )
73
160
74
161
75
162
WSResponse = collections .namedtuple ('WSResponse' , ['data' ])
76
163
77
164
78
- def GET (configuration , url , query_params , _request_timeout , headers ):
165
+ def GET (configuration , url , query_params , _request_timeout , _preload_content ,
166
+ headers ):
79
167
# switch protocols from http to websocket
80
168
url = url .replace ('http://' , 'ws://' )
81
169
url = url .replace ('https://' , 'wss://' )
@@ -105,10 +193,11 @@ def GET(configuration, url, query_params, _request_timeout, headers):
105
193
else :
106
194
url += '&command=' + quote_plus (commands )
107
195
108
- client = WSClient (configuration , url , headers )
109
- if client .errors :
110
- raise ApiException (
111
- status = 0 ,
112
- reason = '\n ' .join ([str (error ) for error in client .errors ])
113
- )
114
- return WSResponse ('%s' % '' .join (client .messages ))
196
+ try :
197
+ client = WSClient (configuration , url , headers )
198
+ if not _preload_content :
199
+ return client
200
+ client .run_forever (timeout = _request_timeout )
201
+ return WSResponse ('%s' % '' .join (client .read_all ()))
202
+ except (Exception , KeyboardInterrupt , SystemExit ) as e :
203
+ raise ApiException (status = 0 , reason = str (e ))
0 commit comments