Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

Commit 7bf04b3

Browse files
committed
Rework how the PortForward._proxy thread determines when and how to terminate.
1 parent 72e3725 commit 7bf04b3

File tree

1 file changed

+78
-73
lines changed

1 file changed

+78
-73
lines changed

stream/ws_client.py

+78-73
Original file line numberDiff line numberDiff line change
@@ -238,33 +238,51 @@ def __init__(self, websocket, ports):
238238

239239
self.websocket = websocket
240240
self.local_ports = {}
241-
for ix, local_remote in enumerate(ports):
242-
self.local_ports[local_remote[0]] = self._Port(ix, local_remote[1])
241+
for ix, port_number in enumerate(ports):
242+
self.local_ports[port_number] = self._Port(ix, port_number)
243+
# There is a thread run per PortForward instance which performs the translation between the
244+
# raw socket data sent by the python application and the websocket protocol. This thread
245+
# terminates after either side has closed all ports, and after flushing all pending data.
243246
threading.Thread(
244-
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
247+
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
248+
target=self._proxy,
249+
daemon=True
245250
).start()
246251

247-
def socket(self, local_number):
248-
if local_number not in self.local_ports:
252+
def socket(self, port_number):
253+
if port_number not in self.local_ports:
249254
raise ValueError("Invalid port number")
250-
return self.local_ports[local_number].socket
255+
return self.local_ports[port_number].socket
251256

252-
def error(self, local_number):
253-
if local_number not in self.local_ports:
257+
def error(self, port_number):
258+
if port_number not in self.local_ports:
254259
raise ValueError("Invalid port number")
255-
return self.local_ports[local_number].error
260+
return self.local_ports[port_number].error
256261

257262
def close(self):
258263
for port in self.local_ports.values():
259264
port.socket.close()
260265

261266
class _Port:
262-
def __init__(self, ix, remote_number):
263-
self.remote_number = remote_number
267+
def __init__(self, ix, port_number):
268+
# The remote port number
269+
self.port_number = port_number
270+
# The websocket channel byte number for this port
264271
self.channel = bytes([ix * 2])
272+
# A socket pair is created to provide a means of translating the data flow
273+
# between the python application and the kubernetes websocket. The self.python
274+
# half of the socket pair is used by the _proxy method to receive and send data
275+
# to the running python application.
265276
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
277+
# The self.socket half of the pair is used by the python application to send
278+
# and receive data to the eventual pod port. It is wrapped in the _Socket class
279+
# because a socket pair is an AF_UNIX socket, not a AF_NET socket. This allows
280+
# intercepting setting AF_INET socket options that would error against an AD_UNIX
281+
# socket.
266282
self.socket = self._Socket(s)
283+
# Data accumulated from the websocket to be sent to the python application.
267284
self.data = b''
285+
# All data sent from kubernetes on the port error channel.
268286
self.error = None
269287

270288
class _Socket:
@@ -285,42 +303,44 @@ def setsockopt(self, level, optname, value):
285303
def _proxy(self):
286304
channel_ports = []
287305
channel_initialized = []
288-
python_ports = {}
289-
rlist = []
306+
local_ports = {}
290307
for port in self.local_ports.values():
291308
# Setup the data channel for this port number
292309
channel_ports.append(port)
293310
channel_initialized.append(False)
294311
# Setup the error channel for this port number
295312
channel_ports.append(port)
296313
channel_initialized.append(False)
297-
python_ports[port.python] = port
298-
rlist.append(port.python)
299-
rlist.append(self.websocket.sock)
314+
port.python.setblocking(True)
315+
local_ports[port.python] = port
316+
# The data to send on the websocket socket
300317
kubernetes_data = b''
301318
while True:
302-
wlist = []
319+
rlist = [] # List of sockets to read from
320+
wlist = [] # List of sockets to write to
321+
if self.websocket.connected:
322+
rlist.append(self.websocket)
323+
if kubernetes_data:
324+
wlist.append(self.websocket)
325+
all_closed = True
303326
for port in self.local_ports.values():
304-
if port.data:
305-
wlist.append(port.python)
306-
if kubernetes_data:
307-
wlist.append(self.websocket.sock)
327+
if port.python.fileno() != -1:
328+
if port.data:
329+
wlist.append(port.python)
330+
all_closed = False
331+
else:
332+
if self.websocket.connected:
333+
rlist.append(port.python)
334+
all_closed = False
335+
else:
336+
port.python.close()
337+
if all_closed and (not self.websocket.connected or not kubernetes_data):
338+
self.websocket.close()
339+
return
308340
r, w, _ = select.select(rlist, wlist, [])
309-
for s in w:
310-
if s == self.websocket.sock:
311-
sent = self.websocket.sock.send(kubernetes_data)
312-
kubernetes_data = kubernetes_data[sent:]
313-
else:
314-
port = python_ports[s]
315-
sent = port.python.send(port.data)
316-
port.data = port.data[sent:]
317-
for s in r:
318-
if s == self.websocket.sock:
341+
for sock in r:
342+
if sock == self.websocket:
319343
opcode, frame = self.websocket.recv_data_frame(True)
320-
if opcode == ABNF.OPCODE_CLOSE:
321-
for port in self.local_ports.values():
322-
port.python.close()
323-
return
324344
if opcode == ABNF.OPCODE_BINARY:
325345
if not frame.data:
326346
raise RuntimeError("Unexpected frame data size")
@@ -341,27 +361,32 @@ def _proxy(self):
341361
"Unexpected initial channel frame data size"
342362
)
343363
port_number = frame.data[1] + (frame.data[2] * 256)
344-
if port_number != port.remote_number:
364+
if port_number != port.port_number:
345365
raise RuntimeError(
346366
"Unexpected port number in initial channel frame: " + str(port_number)
347367
)
348368
channel_initialized[channel] = True
349-
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
369+
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
350370
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
351371
else:
352-
port = python_ports[s]
372+
port = local_ports[sock]
353373
data = port.python.recv(1024 * 1024)
354374
if data:
355375
kubernetes_data += ABNF.create_frame(
356376
port.channel + data,
357377
ABNF.OPCODE_BINARY,
358378
).format()
359379
else:
360-
port.python.close()
361-
rlist.remove(s)
362-
if len(rlist) == 1:
363-
self.websocket.close()
364-
return
380+
if not port.data:
381+
port.python.close()
382+
for sock in w:
383+
if sock == self.websocket:
384+
sent = self.websocket.sock.send(kubernetes_data)
385+
kubernetes_data = kubernetes_data[sent:]
386+
else:
387+
port = local_ports[sock]
388+
sent = port.python.send(port.data)
389+
port.data = port.data[sent:]
365390

366391

367392
def get_websocket_url(url, query_params=None):
@@ -451,38 +476,18 @@ def portforward_call(configuration, _method, url, **kwargs):
451476
query_params = kwargs.get("query_params")
452477

453478
ports = []
454-
for ix in range(len(query_params)):
455-
if query_params[ix][0] == 'ports':
456-
remote_ports = []
457-
for port in query_params[ix][1].split(','):
479+
for param, value in query_params:
480+
if param == 'ports':
481+
for port in value.split(','):
458482
try:
459-
local_remote = port.split(':')
460-
if len(local_remote) > 2:
461-
raise ValueError
462-
if len(local_remote) == 1:
463-
local_remote[0] = int(local_remote[0])
464-
if not (0 < local_remote[0] < 65536):
465-
raise ValueError
466-
local_remote.append(local_remote[0])
467-
elif len(local_remote) == 2:
468-
if local_remote[0]:
469-
local_remote[0] = int(local_remote[0])
470-
if not (0 <= local_remote[0] < 65536):
471-
raise ValueError
472-
else:
473-
local_remote[0] = 0
474-
local_remote[1] = int(local_remote[1])
475-
if not (0 < local_remote[1] < 65536):
476-
raise ValueError
477-
if not local_remote[0]:
478-
local_remote[0] = len(ports) + 1
479-
else:
480-
raise ValueError
481-
ports.append(local_remote)
482-
remote_ports.append(str(local_remote[1]))
483+
port_number = int(port)
483484
except ValueError:
484-
raise ApiValueError("Invalid port number `" + port + "`")
485-
query_params[ix] = ('ports', ','.join(remote_ports))
485+
raise ApiValueError("Invalid port number: %s" % port)
486+
if not (0 < port_number < 65536):
487+
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
488+
if port_number in ports:
489+
raise ApiValueError("Duplicate port numbers: %s" % port)
490+
ports.append(port_number)
486491
if not ports:
487492
raise ApiValueError("Missing required parameter `ports`")
488493

0 commit comments

Comments
 (0)