12
12
13
13
from .rest import ApiException
14
14
15
+ import select
15
16
import certifi
17
+ import time
16
18
import collections
17
- import websocket
19
+ from websocket import WebSocket , ABNF , enableTrace
18
20
import six
19
21
import ssl
20
22
from six .moves .urllib .parse import urlencode
21
23
from six .moves .urllib .parse import quote_plus
22
24
25
+ STDIN_CHANNEL = 0
26
+ STDOUT_CHANNEL = 1
27
+ STDERR_CHANNEL = 2
28
+
23
29
24
30
class WSClient :
25
31
def __init__ (self , configuration , url , headers ):
26
- self .messages = []
27
- self .errors = []
28
- websocket .enableTrace (False )
29
- header = None
32
+ """A websocket client with support for channels.
33
+
34
+ Exec command uses different channels for different streams. for
35
+ example, 0 is stdin, 1 is stdout and 2 is stderr. Some other API calls
36
+ like port forwarding can forward different pods' streams to different
37
+ channels.
38
+ """
39
+ enableTrace (False )
40
+ header = []
41
+ self ._connected = False
42
+ self ._channels = {}
43
+ self ._all = ""
30
44
31
45
# We just need to pass the Authorization, ignore all the other
32
46
# http headers we get from the generated code
33
- if 'Authorization' in headers :
34
- header = "Authorization: %s" % headers ['Authorization' ]
35
-
36
- self .ws = websocket .WebSocketApp (url ,
37
- on_message = self .on_message ,
38
- on_error = self .on_error ,
39
- on_close = self .on_close ,
40
- header = [header ] if header else None )
41
- self .ws .on_open = self .on_open
47
+ if headers and 'authorization' in headers :
48
+ header .append ("authorization: %s" % headers ['authorization' ])
42
49
43
50
if url .startswith ('wss://' ) and configuration .verify_ssl :
44
51
ssl_opts = {
@@ -52,30 +59,145 @@ def __init__(self, configuration, url, headers):
52
59
else :
53
60
ssl_opts = {'cert_reqs' : ssl .CERT_NONE }
54
61
55
- self .ws .run_forever (sslopt = ssl_opts )
56
-
57
- def on_message (self , ws , message ):
58
- if message [0 ] == '\x01 ' :
59
- message = message [1 :]
60
- if message :
61
- if six .PY3 and isinstance (message , six .binary_type ):
62
- message = message .decode ('utf-8' )
63
- self .messages .append (message )
64
-
65
- def on_error (self , ws , error ):
66
- self .errors .append (error )
67
-
68
- def on_close (self , ws ):
69
- pass
70
-
71
- def on_open (self , ws ):
72
- pass
62
+ self .sock = WebSocket (sslopt = ssl_opts , skip_utf8_validation = False )
63
+ self .sock .connect (url , header = header )
64
+ self ._connected = True
65
+
66
+ def peek_channel (self , channel , timeout = 0 ):
67
+ """Peek a channel and return part of the input,
68
+ empty string otherwise."""
69
+ self .update (timeout = timeout )
70
+ if channel in self ._channels :
71
+ return self ._channels [channel ]
72
+ return ""
73
+
74
+ def read_channel (self , channel , timeout = 0 ):
75
+ """Read data from a channel."""
76
+ if channel not in self ._channels :
77
+ ret = self .peek_channel (channel , timeout )
78
+ else :
79
+ ret = self ._channels [channel ]
80
+ if channel in self ._channels :
81
+ del self ._channels [channel ]
82
+ return ret
83
+
84
+ def readline_channel (self , channel , timeout = None ):
85
+ """Read a line from a channel."""
86
+ if timeout is None :
87
+ timeout = float ("inf" )
88
+ start = time .time ()
89
+ while self .is_open () and time .time () - start < timeout :
90
+ if channel in self ._channels :
91
+ data = self ._channels [channel ]
92
+ if "\n " in data :
93
+ index = data .find ("\n " )
94
+ ret = data [:index ]
95
+ data = data [index + 1 :]
96
+ if data :
97
+ self ._channels [channel ] = data
98
+ else :
99
+ del self ._channels [channel ]
100
+ return ret
101
+ self .update (timeout = (timeout - time .time () + start ))
102
+
103
+ def write_channel (self , channel , data ):
104
+ """Write data to a channel."""
105
+ self .sock .send (chr (channel ) + data )
106
+
107
+ def peek_stdout (self , timeout = 0 ):
108
+ """Same as peek_channel with channel=1."""
109
+ return self .peek_channel (STDOUT_CHANNEL , timeout = timeout )
110
+
111
+ def read_stdout (self , timeout = None ):
112
+ """Same as read_channel with channel=1."""
113
+ return self .read_channel (STDOUT_CHANNEL , timeout = timeout )
114
+
115
+ def readline_stdout (self , timeout = None ):
116
+ """Same as readline_channel with channel=1."""
117
+ return self .readline_channel (STDOUT_CHANNEL , timeout = timeout )
118
+
119
+ def peek_stderr (self , timeout = 0 ):
120
+ """Same as peek_channel with channel=2."""
121
+ return self .peek_channel (STDERR_CHANNEL , timeout = timeout )
122
+
123
+ def read_stderr (self , timeout = None ):
124
+ """Same as read_channel with channel=2."""
125
+ return self .read_channel (STDERR_CHANNEL , timeout = timeout )
126
+
127
+ def readline_stderr (self , timeout = None ):
128
+ """Same as readline_channel with channel=2."""
129
+ return self .readline_channel (STDERR_CHANNEL , timeout = timeout )
130
+
131
+ def read_all (self ):
132
+ """Read all of the inputs with the same order they recieved. The channel
133
+ information would be part of the string. This is useful for
134
+ non-interactive call where a set of command passed to the API call and
135
+ their result is needed after the call is concluded.
136
+
137
+ TODO: Maybe we can process this and return a more meaningful map with
138
+ channels mapped for each input.
139
+ """
140
+ out = self ._all
141
+ self ._all = ""
142
+ self ._channels = {}
143
+ return out
144
+
145
+ def is_open (self ):
146
+ """True if the connection is still alive."""
147
+ return self ._connected
148
+
149
+ def write_stdin (self , data ):
150
+ """The same as write_channel with channel=0."""
151
+ self .write_channel (STDIN_CHANNEL , data )
152
+
153
+ def update (self , timeout = 0 ):
154
+ """Update channel buffers with at most one complete frame of input."""
155
+ if not self .is_open ():
156
+ return
157
+ if not self .sock .connected :
158
+ self ._connected = False
159
+ return
160
+ r , _ , _ = select .select (
161
+ (self .sock .sock , ), (), (), timeout )
162
+ if r :
163
+ op_code , frame = self .sock .recv_data_frame (True )
164
+ if op_code == ABNF .OPCODE_CLOSE :
165
+ self ._connected = False
166
+ return
167
+ elif op_code == ABNF .OPCODE_BINARY or op_code == ABNF .OPCODE_TEXT :
168
+ data = frame .data
169
+ if six .PY3 :
170
+ data = data .decode ("utf-8" )
171
+ self ._all += data
172
+ if len (data ) > 1 :
173
+ channel = ord (data [0 ])
174
+ data = data [1 :]
175
+ if data :
176
+ if channel not in self ._channels :
177
+ self ._channels [channel ] = data
178
+ else :
179
+ self ._channels [channel ] += data
180
+
181
+ def run_forever (self , timeout = None ):
182
+ """Wait till connection is closed or timeout reached. Buffer any input
183
+ received during this time."""
184
+ if timeout :
185
+ start = time .time ()
186
+ while self .is_open () and time .time () - start < timeout :
187
+ self .update (timeout = (timeout - time .time () + start ))
188
+ else :
189
+ while self .is_open ():
190
+ self .update (timeout = None )
73
191
74
192
75
193
WSResponse = collections .namedtuple ('WSResponse' , ['data' ])
76
194
77
195
78
- def GET (configuration , url , query_params , _request_timeout , headers ):
196
+ def websocket_call (configuration , url , query_params , _request_timeout ,
197
+ _preload_content , headers ):
198
+ """An internal function to be called in api-client when a websocket
199
+ connection is required."""
200
+
79
201
# switch protocols from http to websocket
80
202
url = url .replace ('http://' , 'ws://' )
81
203
url = url .replace ('https://' , 'wss://' )
@@ -105,10 +227,11 @@ def GET(configuration, url, query_params, _request_timeout, headers):
105
227
else :
106
228
url += '&command=' + quote_plus (commands )
107
229
108
- client = WSClient (configuration , url , headers )
109
- if client .errors :
110
- raise ApiException (
111
- status = 0 ,
112
- reason = '\n ' .join ([str (error ) for error in client .errors ])
113
- )
114
- return WSResponse ('%s' % '' .join (client .messages ))
230
+ try :
231
+ client = WSClient (configuration , url , headers )
232
+ if not _preload_content :
233
+ return client
234
+ client .run_forever (timeout = _request_timeout )
235
+ return WSResponse ('%s' % '' .join (client .read_all ()))
236
+ except (Exception , KeyboardInterrupt , SystemExit ) as e :
237
+ raise ApiException (status = 0 , reason = str (e ))
0 commit comments