diff --git a/CHANGELOG.md b/CHANGELOG.md index 34cdfc1a5f..3ded7b71c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# v1.0.0b2 +- Support exec calls in both interactive and non-interactive mode #58 + # v1.0.0b1 - Support insecure-skip-tls-verify config flag #99 diff --git a/examples/exec.py b/examples/exec.py new file mode 100644 index 0000000000..f9b21b6349 --- /dev/null +++ b/examples/exec.py @@ -0,0 +1,94 @@ +import time + +from kubernetes import config +from kubernetes.client import configuration +from kubernetes.client.apis import core_v1_api +from kubernetes.client.rest import ApiException + +config.load_kube_config() +configuration.assert_hostname = False +api = core_v1_api.CoreV1Api() +name = 'busybox-test' + +resp = None +try: + resp = api.read_namespaced_pod(name=name, + namespace='default') +except ApiException as e: + if e.status != 404: + print("Unknown error: %s" % e) + exit(1) + +if not resp: + print("Pod %s does not exits. Creating it..." % name) + pod_manifest = { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': { + 'name': name + }, + 'spec': { + 'containers': [{ + 'image': 'busybox', + 'name': 'sleep', + "args": [ + "/bin/sh", + "-c", + "while true;do date;sleep 5; done" + ] + }] + } + } + resp = api.create_namespaced_pod(body=pod_manifest, + namespace='default') + while True: + resp = api.read_namespaced_pod(name=name, + namespace='default') + if resp.status.phase != 'Pending': + break + time.sleep(1) + print("Done.") + + +# calling exec and wait for response. +exec_command = [ + '/bin/sh', + '-c', + 'echo This message goes to stderr >&2; echo This message goes to stdout'] +resp = api.connect_get_namespaced_pod_exec(name, 'default', + command=exec_command, + stderr=True, stdin=False, + stdout=True, tty=False) +print("Response: " + resp) + +# Calling exec interactively. +exec_command = ['/bin/sh'] +resp = api.connect_get_namespaced_pod_exec(name, 'default', + command=exec_command, + stderr=True, stdin=True, + stdout=True, tty=False, + + _preload_content=False) +commands = [ + "echo test1", + "echo \"This message goes to stderr\" >&2", +] +while resp.is_open(): + resp.update(timeout=1) + if resp.peek_stdout(): + print("STDOUT: %s" % resp.read_stdout()) + if resp.peek_stderr(): + print("STDERR: %s" % resp.read_stderr()) + if commands: + c = commands.pop(0) + print("Running command... %s\n" % c) + resp.write_stdin(c + "\n") + else: + break + +resp.write_stdin("date\n") +sdate = resp.readline_stdout(timeout=3) +print("Server date command returns: %s" % sdate) +resp.write_stdin("whoami\n") +user = resp.readline_stdout(timeout=3) +print("Server user is: %s" % user) diff --git a/kubernetes/client/api_client.py b/kubernetes/client/api_client.py index 6dbe7137d3..7dca16a161 100644 --- a/kubernetes/client/api_client.py +++ b/kubernetes/client/api_client.py @@ -347,12 +347,12 @@ def request(self, method, url, query_params=None, headers=None, # FIXME(dims) : We need a better way to figure out which # calls end up using web sockets if url.endswith('/exec') and (method == "GET" or method == "POST"): - return ws_client.GET(self.config, - url, - query_params=query_params, - _request_timeout=_request_timeout, - headers=headers) - + return ws_client.websocket_call(self.config, + url, + query_params=query_params, + _request_timeout=_request_timeout, + _preload_content=_preload_content, + headers=headers) if method == "GET": return self.rest_client.GET(url, query_params=query_params, diff --git a/kubernetes/client/ws_client.py b/kubernetes/client/ws_client.py index b143400bee..ceaaa72fe8 100644 --- a/kubernetes/client/ws_client.py +++ b/kubernetes/client/ws_client.py @@ -12,33 +12,40 @@ from .rest import ApiException +import select import certifi +import time import collections -import websocket +from websocket import WebSocket, ABNF, enableTrace import six import ssl from six.moves.urllib.parse import urlencode from six.moves.urllib.parse import quote_plus +STDIN_CHANNEL = 0 +STDOUT_CHANNEL = 1 +STDERR_CHANNEL = 2 + class WSClient: def __init__(self, configuration, url, headers): - self.messages = [] - self.errors = [] - websocket.enableTrace(False) - header = None + """A websocket client with support for channels. + + Exec command uses different channels for different streams. for + example, 0 is stdin, 1 is stdout and 2 is stderr. Some other API calls + like port forwarding can forward different pods' streams to different + channels. + """ + enableTrace(False) + header = [] + self._connected = False + self._channels = {} + self._all = "" # We just need to pass the Authorization, ignore all the other # http headers we get from the generated code - if 'Authorization' in headers: - header = "Authorization: %s" % headers['Authorization'] - - self.ws = websocket.WebSocketApp(url, - on_message=self.on_message, - on_error=self.on_error, - on_close=self.on_close, - header=[header] if header else None) - self.ws.on_open = self.on_open + if headers and 'authorization' in headers: + header.append("authorization: %s" % headers['authorization']) if url.startswith('wss://') and configuration.verify_ssl: ssl_opts = { @@ -52,30 +59,145 @@ def __init__(self, configuration, url, headers): else: ssl_opts = {'cert_reqs': ssl.CERT_NONE} - self.ws.run_forever(sslopt=ssl_opts) - - def on_message(self, ws, message): - if message[0] == '\x01': - message = message[1:] - if message: - if six.PY3 and isinstance(message, six.binary_type): - message = message.decode('utf-8') - self.messages.append(message) - - def on_error(self, ws, error): - self.errors.append(error) - - def on_close(self, ws): - pass - - def on_open(self, ws): - pass + self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False) + self.sock.connect(url, header=header) + self._connected = True + + def peek_channel(self, channel, timeout=0): + """Peek a channel and return part of the input, + empty string otherwise.""" + self.update(timeout=timeout) + if channel in self._channels: + return self._channels[channel] + return "" + + def read_channel(self, channel, timeout=0): + """Read data from a channel.""" + if channel not in self._channels: + ret = self.peek_channel(channel, timeout) + else: + ret = self._channels[channel] + if channel in self._channels: + del self._channels[channel] + return ret + + def readline_channel(self, channel, timeout=None): + """Read a line from a channel.""" + if timeout is None: + timeout = float("inf") + start = time.time() + while self.is_open() and time.time() - start < timeout: + if channel in self._channels: + data = self._channels[channel] + if "\n" in data: + index = data.find("\n") + ret = data[:index] + data = data[index+1:] + if data: + self._channels[channel] = data + else: + del self._channels[channel] + return ret + self.update(timeout=(timeout - time.time() + start)) + + def write_channel(self, channel, data): + """Write data to a channel.""" + self.sock.send(chr(channel) + data) + + def peek_stdout(self, timeout=0): + """Same as peek_channel with channel=1.""" + return self.peek_channel(STDOUT_CHANNEL, timeout=timeout) + + def read_stdout(self, timeout=None): + """Same as read_channel with channel=1.""" + return self.read_channel(STDOUT_CHANNEL, timeout=timeout) + + def readline_stdout(self, timeout=None): + """Same as readline_channel with channel=1.""" + return self.readline_channel(STDOUT_CHANNEL, timeout=timeout) + + def peek_stderr(self, timeout=0): + """Same as peek_channel with channel=2.""" + return self.peek_channel(STDERR_CHANNEL, timeout=timeout) + + def read_stderr(self, timeout=None): + """Same as read_channel with channel=2.""" + return self.read_channel(STDERR_CHANNEL, timeout=timeout) + + def readline_stderr(self, timeout=None): + """Same as readline_channel with channel=2.""" + return self.readline_channel(STDERR_CHANNEL, timeout=timeout) + + def read_all(self): + """Read all of the inputs with the same order they recieved. The channel + information would be part of the string. This is useful for + non-interactive call where a set of command passed to the API call and + their result is needed after the call is concluded. + + TODO: Maybe we can process this and return a more meaningful map with + channels mapped for each input. + """ + out = self._all + self._all = "" + self._channels = {} + return out + + def is_open(self): + """True if the connection is still alive.""" + return self._connected + + def write_stdin(self, data): + """The same as write_channel with channel=0.""" + self.write_channel(STDIN_CHANNEL, data) + + def update(self, timeout=0): + """Update channel buffers with at most one complete frame of input.""" + if not self.is_open(): + return + if not self.sock.connected: + self._connected = False + return + r, _, _ = select.select( + (self.sock.sock, ), (), (), timeout) + if r: + op_code, frame = self.sock.recv_data_frame(True) + if op_code == ABNF.OPCODE_CLOSE: + self._connected = False + return + elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT: + data = frame.data + if six.PY3: + data = data.decode("utf-8") + self._all += data + if len(data) > 1: + channel = ord(data[0]) + data = data[1:] + if data: + if channel not in self._channels: + self._channels[channel] = data + else: + self._channels[channel] += data + + def run_forever(self, timeout=None): + """Wait till connection is closed or timeout reached. Buffer any input + received during this time.""" + if timeout: + start = time.time() + while self.is_open() and time.time() - start < timeout: + self.update(timeout=(timeout - time.time() + start)) + else: + while self.is_open(): + self.update(timeout=None) WSResponse = collections.namedtuple('WSResponse', ['data']) -def GET(configuration, url, query_params, _request_timeout, headers): +def websocket_call(configuration, url, query_params, _request_timeout, + _preload_content, headers): + """An internal function to be called in api-client when a websocket + connection is required.""" + # switch protocols from http to websocket url = url.replace('http://', 'ws://') url = url.replace('https://', 'wss://') @@ -105,10 +227,11 @@ def GET(configuration, url, query_params, _request_timeout, headers): else: url += '&command=' + quote_plus(commands) - client = WSClient(configuration, url, headers) - if client.errors: - raise ApiException( - status=0, - reason='\n'.join([str(error) for error in client.errors]) - ) - return WSResponse('%s' % ''.join(client.messages)) + try: + client = WSClient(configuration, url, headers) + if not _preload_content: + return client + client.run_forever(timeout=_request_timeout) + return WSResponse('%s' % ''.join(client.read_all())) + except (Exception, KeyboardInterrupt, SystemExit) as e: + raise ApiException(status=0, reason=str(e)) diff --git a/kubernetes/e2e_test/base.py b/kubernetes/e2e_test/base.py index 5f04ab7e8e..ee19b14ac4 100644 --- a/kubernetes/e2e_test/base.py +++ b/kubernetes/e2e_test/base.py @@ -42,4 +42,5 @@ def get_e2e_configuration(): if config.host is None: raise unittest.SkipTest('Unable to find a running Kubernetes instance') print('Running test against : %s' % config.host) + config.assert_hostname = False return config diff --git a/kubernetes/e2e_test/test_client.py b/kubernetes/e2e_test/test_client.py index 8bc9b3d30a..6f19fbdb93 100644 --- a/kubernetes/e2e_test/test_client.py +++ b/kubernetes/e2e_test/test_client.py @@ -18,10 +18,14 @@ from kubernetes.client import api_client from kubernetes.client.apis import core_v1_api -from kubernetes.client.configuration import configuration from kubernetes.e2e_test import base +def short_uuid(): + id = str(uuid.uuid4()) + return id[-12:] + + class TestClient(unittest.TestCase): @classmethod @@ -32,7 +36,7 @@ def test_pod_apis(self): client = api_client.ApiClient(config=self.config) api = core_v1_api.CoreV1Api(client) - name = 'busybox-test-' + str(uuid.uuid4()) + name = 'busybox-test-' + short_uuid() pod_manifest = { 'apiVersion': 'v1', 'kind': 'Pod', @@ -68,7 +72,7 @@ def test_pod_apis(self): exec_command = ['/bin/sh', '-c', - 'for i in $(seq 1 3); do date; sleep 1; done'] + 'for i in $(seq 1 3); do date; done'] resp = api.connect_get_namespaced_pod_exec(name, 'default', command=exec_command, stderr=False, stdin=False, @@ -78,12 +82,29 @@ def test_pod_apis(self): exec_command = 'uptime' resp = api.connect_post_namespaced_pod_exec(name, 'default', - command=exec_command, - stderr=False, stdin=False, - stdout=True, tty=False) + command=exec_command, + stderr=False, stdin=False, + stdout=True, tty=False) print('EXEC response : %s' % resp) self.assertEqual(1, len(resp.splitlines())) + resp = api.connect_post_namespaced_pod_exec(name, 'default', + command='/bin/sh', + stderr=True, stdin=True, + stdout=True, tty=False, + _preload_content=False) + resp.write_stdin("echo test string 1\n") + line = resp.readline_stdout(timeout=5) + self.assertFalse(resp.peek_stderr()) + self.assertEqual("test string 1", line) + resp.write_stdin("echo test string 2 >&2\n") + line = resp.readline_stderr(timeout=5) + self.assertFalse(resp.peek_stdout()) + self.assertEqual("test string 2", line) + resp.write_stdin("exit\n") + resp.update(timeout=5) + self.assertFalse(resp.is_open()) + number_of_pods = len(api.list_pod_for_all_namespaces().items) self.assertTrue(number_of_pods > 0) @@ -94,7 +115,7 @@ def test_service_apis(self): client = api_client.ApiClient(config=self.config) api = core_v1_api.CoreV1Api(client) - name = 'frontend-' + str(uuid.uuid4()) + name = 'frontend-' + short_uuid() service_manifest = {'apiVersion': 'v1', 'kind': 'Service', 'metadata': {'labels': {'name': name}, @@ -133,7 +154,7 @@ def test_replication_controller_apis(self): client = api_client.ApiClient(config=self.config) api = core_v1_api.CoreV1Api(client) - name = 'frontend-' + str(uuid.uuid4()) + name = 'frontend-' + short_uuid() rc_manifest = { 'apiVersion': 'v1', 'kind': 'ReplicationController', @@ -166,7 +187,7 @@ def test_configmap_apis(self): client = api_client.ApiClient(config=self.config) api = core_v1_api.CoreV1Api(client) - name = 'test-configmap-' + str(uuid.uuid4()) + name = 'test-configmap-' + short_uuid() test_configmap = { "kind": "ConfigMap", "apiVersion": "v1", @@ -195,7 +216,7 @@ def test_configmap_apis(self): resp = api.delete_namespaced_config_map( name=name, body={}, namespace='default') - resp = api.list_namespaced_config_map('kube-system', pretty=True) + resp = api.list_namespaced_config_map('default', pretty=True) self.assertEqual([], resp.items) def test_node_apis(self):