Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

Commit 3dc7fe0

Browse files
authored
Merge pull request #210 from iciclespider/port-forward
Implement port forwarding.
2 parents 471a678 + 5d39d0d commit 3dc7fe0

File tree

3 files changed

+211
-4
lines changed

3 files changed

+211
-4
lines changed

stream/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .stream import stream
15+
from .stream import stream, portforward

stream/stream.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from . import ws_client
1818

1919

20-
def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
20+
def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs):
2121
"""Override the ApiClient.request method with an alternative websocket based
2222
method and call the supplied Kubernetes API method with that in place."""
23+
if force_kwargs:
24+
for kwarg, value in force_kwargs.items():
25+
kwargs[kwarg] = value
2326
api_client = api_method.__self__.api_client
2427
# old generated code's api client has config. new ones has configuration
2528
try:
@@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
3437
api_client.request = prev_request
3538

3639

37-
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call)
40+
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None)
41+
portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False})

stream/ws_client.py

+204-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from kubernetes.client.rest import ApiException
15+
from kubernetes.client.rest import ApiException, ApiValueError
1616

1717
import certifi
1818
import collections
1919
import select
20+
import socket
2021
import ssl
22+
import threading
2123
import time
2224

2325
import six
@@ -225,6 +227,174 @@ def close(self, **kwargs):
225227
WSResponse = collections.namedtuple('WSResponse', ['data'])
226228

227229

230+
class PortForward:
231+
def __init__(self, websocket, ports):
232+
"""A websocket client with support for port forwarding.
233+
234+
Port Forward command sends on 2 channels per port, a read/write
235+
data channel and a read only error channel. Both channels are sent an
236+
initial frame contaning the port number that channel is associated with.
237+
"""
238+
239+
self.websocket = websocket
240+
self.local_ports = {}
241+
for ix, port_number in enumerate(ports):
242+
self.local_ports[port_number] = self._Port(ix, port_number)
243+
# There is a thread run per PortForward instance which performs the translation between the
244+
# raw socket data sent by the python application and the websocket protocol. This thread
245+
# terminates after either side has closed all ports, and after flushing all pending data.
246+
proxy = threading.Thread(
247+
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
248+
target=self._proxy
249+
)
250+
proxy.daemon = True
251+
proxy.start()
252+
253+
@property
254+
def connected(self):
255+
return self.websocket.connected
256+
257+
def socket(self, port_number):
258+
if port_number not in self.local_ports:
259+
raise ValueError("Invalid port number")
260+
return self.local_ports[port_number].socket
261+
262+
def error(self, port_number):
263+
if port_number not in self.local_ports:
264+
raise ValueError("Invalid port number")
265+
return self.local_ports[port_number].error
266+
267+
def close(self):
268+
for port in self.local_ports.values():
269+
port.socket.close()
270+
271+
class _Port:
272+
def __init__(self, ix, port_number):
273+
# The remote port number
274+
self.port_number = port_number
275+
# The websocket channel byte number for this port
276+
self.channel = six.int2byte(ix * 2)
277+
# A socket pair is created to provide a means of translating the data flow
278+
# between the python application and the kubernetes websocket. The self.python
279+
# half of the socket pair is used by the _proxy method to receive and send data
280+
# to the running python application.
281+
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
282+
# The self.socket half of the pair is used by the python application to send
283+
# and receive data to the eventual pod port. It is wrapped in the _Socket class
284+
# because a socket pair is an AF_UNIX socket, not a AF_INET socket. This allows
285+
# intercepting setting AF_INET socket options that would error against an AF_UNIX
286+
# socket.
287+
self.socket = self._Socket(s)
288+
# Data accumulated from the websocket to be sent to the python application.
289+
self.data = b''
290+
# All data sent from kubernetes on the port error channel.
291+
self.error = None
292+
293+
class _Socket:
294+
def __init__(self, socket):
295+
self._socket = socket
296+
297+
def __getattr__(self, name):
298+
return getattr(self._socket, name)
299+
300+
def setsockopt(self, level, optname, value):
301+
# The following socket option is not valid with a socket created from socketpair,
302+
# and is set by the http.client.HTTPConnection.connect method.
303+
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
304+
return
305+
self._socket.setsockopt(level, optname, value)
306+
307+
# Proxy all socket data between the python code and the kubernetes websocket.
308+
def _proxy(self):
309+
channel_ports = []
310+
channel_initialized = []
311+
local_ports = {}
312+
for port in self.local_ports.values():
313+
# Setup the data channel for this port number
314+
channel_ports.append(port)
315+
channel_initialized.append(False)
316+
# Setup the error channel for this port number
317+
channel_ports.append(port)
318+
channel_initialized.append(False)
319+
port.python.setblocking(True)
320+
local_ports[port.python] = port
321+
# The data to send on the websocket socket
322+
kubernetes_data = b''
323+
while True:
324+
rlist = [] # List of sockets to read from
325+
wlist = [] # List of sockets to write to
326+
if self.websocket.connected:
327+
rlist.append(self.websocket)
328+
if kubernetes_data:
329+
wlist.append(self.websocket)
330+
local_all_closed = True
331+
for port in self.local_ports.values():
332+
if port.python.fileno() != -1:
333+
if port.error or not self.websocket.connected:
334+
if port.data:
335+
wlist.append(port.python)
336+
local_all_closed = False
337+
else:
338+
port.python.close()
339+
else:
340+
rlist.append(port.python)
341+
if port.data:
342+
wlist.append(port.python)
343+
local_all_closed = False
344+
if local_all_closed and not (self.websocket.connected and kubernetes_data):
345+
self.websocket.close()
346+
return
347+
r, w, _ = select.select(rlist, wlist, [])
348+
for sock in r:
349+
if sock == self.websocket:
350+
opcode, frame = self.websocket.recv_data_frame(True)
351+
if opcode == ABNF.OPCODE_BINARY:
352+
if not frame.data:
353+
raise RuntimeError("Unexpected frame data size")
354+
channel = six.byte2int(frame.data)
355+
if channel >= len(channel_ports):
356+
raise RuntimeError("Unexpected channel number: %s" % channel)
357+
port = channel_ports[channel]
358+
if channel_initialized[channel]:
359+
if channel % 2:
360+
if port.error is None:
361+
port.error = ''
362+
port.error += frame.data[1:].decode()
363+
else:
364+
port.data += frame.data[1:]
365+
else:
366+
if len(frame.data) != 3:
367+
raise RuntimeError(
368+
"Unexpected initial channel frame data size"
369+
)
370+
port_number = six.byte2int(frame.data[1:2]) + (six.byte2int(frame.data[2:3]) * 256)
371+
if port_number != port.port_number:
372+
raise RuntimeError(
373+
"Unexpected port number in initial channel frame: %s" % port_number
374+
)
375+
channel_initialized[channel] = True
376+
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
377+
raise RuntimeError("Unexpected websocket opcode: %s" % opcode)
378+
else:
379+
port = local_ports[sock]
380+
data = port.python.recv(1024 * 1024)
381+
if data:
382+
kubernetes_data += ABNF.create_frame(
383+
port.channel + data,
384+
ABNF.OPCODE_BINARY,
385+
).format()
386+
else:
387+
port.python.close()
388+
for sock in w:
389+
if sock == self.websocket:
390+
sent = self.websocket.sock.send(kubernetes_data)
391+
kubernetes_data = kubernetes_data[sent:]
392+
else:
393+
port = local_ports[sock]
394+
sent = port.python.send(port.data)
395+
port.data = port.data[sent:]
396+
397+
228398
def get_websocket_url(url, query_params=None):
229399
parsed_url = urlparse(url)
230400
parts = list(parsed_url)
@@ -302,3 +472,36 @@ def websocket_call(configuration, _method, url, **kwargs):
302472
return WSResponse('%s' % ''.join(client.read_all()))
303473
except (Exception, KeyboardInterrupt, SystemExit) as e:
304474
raise ApiException(status=0, reason=str(e))
475+
476+
477+
def portforward_call(configuration, _method, url, **kwargs):
478+
"""An internal function to be called in api-client when a websocket
479+
connection is required for port forwarding. args and kwargs are the
480+
parameters of apiClient.request method."""
481+
482+
query_params = kwargs.get("query_params")
483+
484+
ports = []
485+
for param, value in query_params:
486+
if param == 'ports':
487+
for port in value.split(','):
488+
try:
489+
port_number = int(port)
490+
except ValueError:
491+
raise ApiValueError("Invalid port number: %s" % port)
492+
if not (0 < port_number < 65536):
493+
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
494+
if port_number in ports:
495+
raise ApiValueError("Duplicate port numbers: %s" % port)
496+
ports.append(port_number)
497+
if not ports:
498+
raise ApiValueError("Missing required parameter `ports`")
499+
500+
url = get_websocket_url(url, query_params)
501+
headers = kwargs.get("headers")
502+
503+
try:
504+
websocket = create_websocket(configuration, url, headers)
505+
return PortForward(websocket, ports)
506+
except (Exception, KeyboardInterrupt, SystemExit) as e:
507+
raise ApiException(status=0, reason=str(e))

0 commit comments

Comments
 (0)