diff --git a/api_client.py b/api_client.py new file mode 100644 index 00000000..0e5e14ab --- /dev/null +++ b/api_client.py @@ -0,0 +1,647 @@ +# coding: utf-8 + +""" +Copyright 2016 SmartBear Software + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ref: https://github.com/swagger-api/swagger-codegen +""" + +from __future__ import absolute_import + +from . import models +from . import ws_client +from .rest import RESTClientObject +from .rest import ApiException + +import os +import re +import json +import mimetypes +import tempfile +import threading + +from datetime import datetime +from datetime import date + +# python 2 and python 3 compatibility library +from six import PY3, integer_types, iteritems, text_type +from six.moves.urllib.parse import quote + +from .configuration import configuration + + +class ApiClient(object): + """ + Generic API client for Swagger client library builds. + + Swagger generic API client. This client handles the client- + server communication, and is invariant across implementations. Specifics of + the methods and models for each application are generated from the Swagger + templates. + + NOTE: This class is auto generated by the swagger code generator program. + Ref: https://github.com/swagger-api/swagger-codegen + Do not edit the class manually. + + :param host: The base path for the server to call. + :param header_name: a header to pass when making calls to the API. + :param header_value: a header value to pass when making calls to the API. + """ + def __init__(self, host=None, header_name=None, header_value=None, + cookie=None, config=configuration): + + """ + Constructor of the class. + """ + self.config = config + self.rest_client = RESTClientObject(config=self.config) + self.default_headers = {} + if header_name is not None: + self.default_headers[header_name] = header_value + if host is None: + self.host = self.config.host + else: + self.host = host + self.cookie = cookie + # Set default User-Agent. + self.user_agent = 'Swagger-Codegen/1.0.0-snapshot/python' + + @property + def user_agent(self): + """ + Gets user agent. + """ + return self.default_headers['User-Agent'] + + @user_agent.setter + def user_agent(self, value): + """ + Sets user agent. + """ + self.default_headers['User-Agent'] = value + + def set_default_header(self, header_name, header_value): + self.default_headers[header_name] = header_value + + def __call_api(self, resource_path, method, + path_params=None, query_params=None, header_params=None, + body=None, post_params=None, files=None, + response_type=None, auth_settings=None, callback=None, + _return_http_data_only=None, collection_formats=None, _preload_content=True, + _request_timeout=None): + + # header parameters + header_params = header_params or {} + header_params.update(self.default_headers) + if self.cookie: + header_params['Cookie'] = self.cookie + if header_params: + header_params = self.sanitize_for_serialization(header_params) + header_params = dict(self.parameters_to_tuples(header_params, + collection_formats)) + + # path parameters + if path_params: + path_params = self.sanitize_for_serialization(path_params) + path_params = self.parameters_to_tuples(path_params, + collection_formats) + for k, v in path_params: + resource_path = resource_path.replace( + '{%s}' % k, quote(str(v))) + + # query parameters + if query_params: + query_params = self.sanitize_for_serialization(query_params) + query_params = self.parameters_to_tuples(query_params, + collection_formats) + + # post parameters + if post_params or files: + post_params = self.prepare_post_parameters(post_params, files) + post_params = self.sanitize_for_serialization(post_params) + post_params = self.parameters_to_tuples(post_params, + collection_formats) + + # auth setting + self.update_params_for_auth(header_params, query_params, auth_settings) + + # body + if body: + body = self.sanitize_for_serialization(body) + + # request url + url = self.host + resource_path + + # perform request and return response + response_data = self.request(method, url, + query_params=query_params, + headers=header_params, + post_params=post_params, body=body, + _preload_content=_preload_content, + _request_timeout=_request_timeout) + + self.last_response = response_data + + return_data = response_data + if _preload_content: + # deserialize response data + if response_type: + return_data = self.deserialize(response_data, response_type) + else: + return_data = None + + if callback: + callback(return_data) if _return_http_data_only else callback((return_data, response_data.status, response_data.getheaders())) + elif _return_http_data_only: + return (return_data) + else: + return (return_data, response_data.status, response_data.getheaders()) + + def sanitize_for_serialization(self, obj): + """ + Builds a JSON POST object. + + If obj is None, return None. + If obj is str, int, long, float, bool, return directly. + If obj is datetime.datetime, datetime.date + convert to string in iso8601 format. + If obj is list, sanitize each element in the list. + If obj is dict, return the dict. + If obj is swagger model, return the properties dict. + + :param obj: The data to serialize. + :return: The serialized form of data. + """ + types = (str, float, bool, bytes) + tuple(integer_types) + (text_type,) + if isinstance(obj, type(None)): + return None + elif isinstance(obj, types): + return obj + elif isinstance(obj, list): + return [self.sanitize_for_serialization(sub_obj) + for sub_obj in obj] + elif isinstance(obj, tuple): + return tuple(self.sanitize_for_serialization(sub_obj) + for sub_obj in obj) + elif isinstance(obj, (datetime, date)): + return obj.isoformat() + else: + if isinstance(obj, dict): + obj_dict = obj + else: + # Convert model obj to dict except + # attributes `swagger_types`, `attribute_map` + # and attributes which value is not None. + # Convert attribute name to json key in + # model definition for request. + obj_dict = {obj.attribute_map[attr]: getattr(obj, attr) + for attr, _ in iteritems(obj.swagger_types) + if getattr(obj, attr) is not None} + + return {key: self.sanitize_for_serialization(val) + for key, val in iteritems(obj_dict)} + + def deserialize(self, response, response_type): + """ + Deserializes response into an object. + + :param response: RESTResponse object to be deserialized. + :param response_type: class literal for + deserialized object, or string of class name. + + :return: deserialized object. + """ + # handle file downloading + # save response body into a tmp file and return the instance + if "file" == response_type: + return self.__deserialize_file(response) + + # fetch data from response object + try: + data = json.loads(response.data) + except ValueError: + data = response.data + + return self.__deserialize(data, response_type) + + def __deserialize(self, data, klass): + """ + Deserializes dict, list, str into an object. + + :param data: dict, list or str. + :param klass: class literal, or string of class name. + + :return: object. + """ + if data is None: + return None + + if type(klass) == str: + if klass.startswith('list['): + sub_kls = re.match('list\[(.*)\]', klass).group(1) + return [self.__deserialize(sub_data, sub_kls) + for sub_data in data] + + if klass.startswith('dict('): + sub_kls = re.match('dict\(([^,]*), (.*)\)', klass).group(2) + return {k: self.__deserialize(v, sub_kls) + for k, v in iteritems(data)} + + # convert str to class + # for native types + if klass in ['int', 'float', 'str', 'bool', + "date", 'datetime', "object"]: + klass = eval(klass) + elif klass == 'long': + klass = int if PY3 else long + # for model types + else: + klass = eval('models.' + klass) + + if klass in integer_types or klass in (float, str, bool): + return self.__deserialize_primitive(data, klass) + elif klass == object: + return self.__deserialize_object(data) + elif klass == date: + return self.__deserialize_date(data) + elif klass == datetime: + return self.__deserialize_datatime(data) + else: + return self.__deserialize_model(data, klass) + + def call_api(self, resource_path, method, + path_params=None, query_params=None, header_params=None, + body=None, post_params=None, files=None, + response_type=None, auth_settings=None, callback=None, + _return_http_data_only=None, collection_formats=None, _preload_content=True, + _request_timeout=None): + """ + Makes the HTTP request (synchronous) and return the deserialized data. + To make an async request, define a function for callback. + + :param resource_path: Path to method endpoint. + :param method: Method to call. + :param path_params: Path parameters in the url. + :param query_params: Query parameters in the url. + :param header_params: Header parameters to be + placed in the request header. + :param body: Request body. + :param post_params dict: Request post form parameters, + for `application/x-www-form-urlencoded`, `multipart/form-data`. + :param auth_settings list: Auth Settings names for the request. + :param response: Response data type. + :param files dict: key -> filename, value -> filepath, + for `multipart/form-data`. + :param callback function: Callback function for asynchronous request. + If provide this parameter, + the request will be called asynchronously. + :param _return_http_data_only: response data without head status code and headers + :param collection_formats: dict of collection formats for path, query, + header, and post parameters. + :param _preload_content: if False, the urllib3.HTTPResponse object will be returned without + reading/decoding response data. Default is True. + :param _request_timeout: timeout setting for this request. If one number provided, it will be total request + timeout. It can also be a pair (tuple) of (connection, read) timeouts. + :return: + If provide parameter callback, + the request will be called asynchronously. + The method will return the request thread. + If parameter callback is None, + then the method will return the response directly. + """ + if callback is None: + return self.__call_api(resource_path, method, + path_params, query_params, header_params, + body, post_params, files, + response_type, auth_settings, callback, + _return_http_data_only, collection_formats, _preload_content, _request_timeout) + else: + thread = threading.Thread(target=self.__call_api, + args=(resource_path, method, + path_params, query_params, + header_params, body, + post_params, files, + response_type, auth_settings, + callback, _return_http_data_only, + collection_formats, _preload_content, _request_timeout)) + thread.start() + return thread + + def request(self, method, url, query_params=None, headers=None, + post_params=None, body=None, _preload_content=True, _request_timeout=None): + """ + Makes the HTTP request using RESTClient. + """ + # FIXME(dims) : We need a better way to figure out which + # calls end up using web sockets + if (url.endswith('/exec') or url.endswith('/attach')) and (method == "GET" or method == "POST"): + return ws_client.websocket_call(self.config, + url, + query_params=query_params, + _request_timeout=_request_timeout, + _preload_content=_preload_content, + headers=headers) + if method == "GET": + return self.rest_client.GET(url, + query_params=query_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + headers=headers) + elif method == "HEAD": + return self.rest_client.HEAD(url, + query_params=query_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + headers=headers) + elif method == "OPTIONS": + return self.rest_client.OPTIONS(url, + query_params=query_params, + headers=headers, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + elif method == "POST": + return self.rest_client.POST(url, + query_params=query_params, + headers=headers, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + elif method == "PUT": + return self.rest_client.PUT(url, + query_params=query_params, + headers=headers, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + elif method == "PATCH": + return self.rest_client.PATCH(url, + query_params=query_params, + headers=headers, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + elif method == "DELETE": + return self.rest_client.DELETE(url, + query_params=query_params, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + else: + raise ValueError( + "http method must be `GET`, `HEAD`, `OPTIONS`," + " `POST`, `PATCH`, `PUT` or `DELETE`." + ) + + def parameters_to_tuples(self, params, collection_formats): + """ + Get parameters as list of tuples, formatting collections. + + :param params: Parameters as dict or list of two-tuples + :param dict collection_formats: Parameter collection formats + :return: Parameters as list of tuples, collections formatted + """ + new_params = [] + if collection_formats is None: + collection_formats = {} + for k, v in iteritems(params) if isinstance(params, dict) else params: + if k in collection_formats: + collection_format = collection_formats[k] + if collection_format == 'multi': + new_params.extend((k, value) for value in v) + else: + if collection_format == 'ssv': + delimiter = ' ' + elif collection_format == 'tsv': + delimiter = '\t' + elif collection_format == 'pipes': + delimiter = '|' + else: # csv is the default + delimiter = ',' + new_params.append( + (k, delimiter.join(str(value) for value in v))) + else: + new_params.append((k, v)) + return new_params + + def prepare_post_parameters(self, post_params=None, files=None): + """ + Builds form parameters. + + :param post_params: Normal form parameters. + :param files: File parameters. + :return: Form parameters with files. + """ + params = [] + + if post_params: + params = post_params + + if files: + for k, v in iteritems(files): + if not v: + continue + file_names = v if type(v) is list else [v] + for n in file_names: + with open(n, 'rb') as f: + filename = os.path.basename(f.name) + filedata = f.read() + mimetype = mimetypes.\ + guess_type(filename)[0] or 'application/octet-stream' + params.append(tuple([k, tuple([filename, filedata, mimetype])])) + + return params + + def select_header_accept(self, accepts): + """ + Returns `Accept` based on an array of accepts provided. + + :param accepts: List of headers. + :return: Accept (e.g. application/json). + """ + if not accepts: + return + + accepts = list(map(lambda x: x.lower(), accepts)) + + if 'application/json' in accepts: + return 'application/json' + else: + return ', '.join(accepts) + + def select_header_content_type(self, content_types): + """ + Returns `Content-Type` based on an array of content_types provided. + + :param content_types: List of content-types. + :return: Content-Type (e.g. application/json). + """ + if not content_types: + return 'application/json' + + content_types = list(map(lambda x: x.lower(), content_types)) + + if 'application/json' in content_types or '*/*' in content_types: + return 'application/json' + else: + return content_types[0] + + def update_params_for_auth(self, headers, querys, auth_settings): + """ + Updates header and query params based on authentication setting. + + :param headers: Header parameters dict to be updated. + :param querys: Query parameters tuple list to be updated. + :param auth_settings: Authentication setting identifiers list. + """ + + if not auth_settings: + return + + for auth in auth_settings: + auth_setting = self.config.auth_settings().get(auth) + if auth_setting: + if not auth_setting['value']: + continue + elif auth_setting['in'] == 'header': + headers[auth_setting['key']] = auth_setting['value'] + elif auth_setting['in'] == 'query': + querys.append((auth_setting['key'], auth_setting['value'])) + else: + raise ValueError( + 'Authentication token must be in `query` or `header`' + ) + + def __deserialize_file(self, response): + """ + Saves response body into a file in a temporary folder, + using the filename from the `Content-Disposition` header if provided. + + :param response: RESTResponse. + :return: file path. + """ + fd, path = tempfile.mkstemp(dir=self.config.temp_folder_path) + os.close(fd) + os.remove(path) + + content_disposition = response.getheader("Content-Disposition") + if content_disposition: + filename = re.\ + search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition).\ + group(1) + path = os.path.join(os.path.dirname(path), filename) + + with open(path, "w") as f: + f.write(response.data) + + return path + + def __deserialize_primitive(self, data, klass): + """ + Deserializes string to primitive type. + + :param data: str. + :param klass: class literal. + + :return: int, long, float, str, bool. + """ + try: + value = klass(data) + except UnicodeEncodeError: + value = unicode(data) + except TypeError: + value = data + return value + + def __deserialize_object(self, value): + """ + Return a original value. + + :return: object. + """ + return value + + def __deserialize_date(self, string): + """ + Deserializes string to date. + + :param string: str. + :return: date. + """ + if not string: + return None + try: + from dateutil.parser import parse + return parse(string).date() + except ImportError: + return string + except ValueError: + raise ApiException( + status=0, + reason="Failed to parse `{0}` into a date object" + .format(string) + ) + + def __deserialize_datatime(self, string): + """ + Deserializes string to datetime. + + The string should be in iso8601 datetime format. + + :param string: str. + :return: datetime. + """ + if not string: + return None + try: + from dateutil.parser import parse + return parse(string) + except ImportError: + return string + except ValueError: + raise ApiException( + status=0, + reason="Failed to parse `{0}` into a datetime object". + format(string) + ) + + def __deserialize_model(self, data, klass): + """ + Deserializes list or dict to model. + + :param data: dict, list. + :param klass: class literal. + :return: model object. + """ + instance = klass() + + if not instance.swagger_types: + return data + + for attr, attr_type in iteritems(instance.swagger_types): + if data is not None \ + and instance.attribute_map[attr] in data\ + and isinstance(data, (list, dict)): + value = data[instance.attribute_map[attr]] + if value is None: + value = [] if isinstance(data, list) else {} + setattr(instance, attr, self.__deserialize(value, attr_type)) + + return instance diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 00000000..3476ff71 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_exception import ConfigException +from .incluster_config import load_incluster_config +from .kube_config import (list_kube_config_contexts, load_kube_config, + new_client_from_config) diff --git a/config/config b/config/config new file mode 120000 index 00000000..30fa1cea --- /dev/null +++ b/config/config @@ -0,0 +1 @@ +config \ No newline at end of file diff --git a/config/config_exception.py b/config/config_exception.py new file mode 100644 index 00000000..23fab022 --- /dev/null +++ b/config/config_exception.py @@ -0,0 +1,17 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ConfigException(Exception): + pass diff --git a/config/incluster_config.py b/config/incluster_config.py new file mode 100644 index 00000000..3ba1113f --- /dev/null +++ b/config/incluster_config.py @@ -0,0 +1,91 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from kubernetes.client import configuration + +from .config_exception import ConfigException + +SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" +SERVICE_PORT_ENV_NAME = "KUBERNETES_SERVICE_PORT" +SERVICE_TOKEN_FILENAME = "/var/run/secrets/kubernetes.io/serviceaccount/token" +SERVICE_CERT_FILENAME = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + + +def _join_host_port(host, port): + """Adapted golang's net.JoinHostPort""" + template = "%s:%s" + host_requires_bracketing = ':' in host or '%' in host + if host_requires_bracketing: + template = "[%s]:%s" + return template % (host, port) + + +class InClusterConfigLoader(object): + + def __init__(self, token_filename, + cert_filename, environ=os.environ): + self._token_filename = token_filename + self._cert_filename = cert_filename + self._environ = environ + + def load_and_set(self): + self._load_config() + self._set_config() + + def _load_config(self): + if (SERVICE_HOST_ENV_NAME not in self._environ or + SERVICE_PORT_ENV_NAME not in self._environ): + raise ConfigException("Service host/port is not set.") + + if (not self._environ[SERVICE_HOST_ENV_NAME] or + not self._environ[SERVICE_PORT_ENV_NAME]): + raise ConfigException("Service host/port is set but empty.") + + self.host = ( + "https://" + _join_host_port(self._environ[SERVICE_HOST_ENV_NAME], + self._environ[SERVICE_PORT_ENV_NAME])) + + if not os.path.isfile(self._token_filename): + raise ConfigException("Service token file does not exists.") + + with open(self._token_filename) as f: + self.token = f.read() + if not self.token: + raise ConfigException("Token file exists but empty.") + + if not os.path.isfile(self._cert_filename): + raise ConfigException( + "Service certification file does not exists.") + + with open(self._cert_filename) as f: + if not f.read(): + raise ConfigException("Cert file exists but empty.") + + self.ssl_ca_cert = self._cert_filename + + def _set_config(self): + configuration.host = self.host + configuration.ssl_ca_cert = self.ssl_ca_cert + configuration.api_key['authorization'] = "bearer " + self.token + + +def load_incluster_config(): + """Use the service account kubernetes gives to pods to connect to kubernetes + cluster. It's intended for clients that expect to be running inside a pod + running on kubernetes. It will raise an exception if called from a process + not running in a kubernetes environment.""" + InClusterConfigLoader(token_filename=SERVICE_TOKEN_FILENAME, + cert_filename=SERVICE_CERT_FILENAME).load_and_set() diff --git a/config/incluster_config_test.py b/config/incluster_config_test.py new file mode 100644 index 00000000..622b31b3 --- /dev/null +++ b/config/incluster_config_test.py @@ -0,0 +1,131 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +from .config_exception import ConfigException +from .incluster_config import (SERVICE_HOST_ENV_NAME, SERVICE_PORT_ENV_NAME, + InClusterConfigLoader, _join_host_port) + +_TEST_TOKEN = "temp_token" +_TEST_CERT = "temp_cert" +_TEST_HOST = "127.0.0.1" +_TEST_PORT = "80" +_TEST_HOST_PORT = "127.0.0.1:80" +_TEST_IPV6_HOST = "::1" +_TEST_IPV6_HOST_PORT = "[::1]:80" + +_TEST_ENVIRON = {SERVICE_HOST_ENV_NAME: _TEST_HOST, + SERVICE_PORT_ENV_NAME: _TEST_PORT} +_TEST_IPV6_ENVIRON = {SERVICE_HOST_ENV_NAME: _TEST_IPV6_HOST, + SERVICE_PORT_ENV_NAME: _TEST_PORT} + + +class InClusterConfigTest(unittest.TestCase): + + def setUp(self): + self._temp_files = [] + + def tearDown(self): + for f in self._temp_files: + os.remove(f) + + def _create_file_with_temp_content(self, content=""): + handler, name = tempfile.mkstemp() + self._temp_files.append(name) + os.write(handler, str.encode(content)) + os.close(handler) + return name + + def get_test_loader( + self, + token_filename=None, + cert_filename=None, + environ=_TEST_ENVIRON): + if not token_filename: + token_filename = self._create_file_with_temp_content(_TEST_TOKEN) + if not cert_filename: + cert_filename = self._create_file_with_temp_content(_TEST_CERT) + return InClusterConfigLoader( + token_filename=token_filename, + cert_filename=cert_filename, + environ=environ) + + def test_join_host_port(self): + self.assertEqual(_TEST_HOST_PORT, + _join_host_port(_TEST_HOST, _TEST_PORT)) + self.assertEqual(_TEST_IPV6_HOST_PORT, + _join_host_port(_TEST_IPV6_HOST, _TEST_PORT)) + + def test_load_config(self): + cert_filename = self._create_file_with_temp_content(_TEST_CERT) + loader = self.get_test_loader(cert_filename=cert_filename) + loader._load_config() + self.assertEqual("https://" + _TEST_HOST_PORT, loader.host) + self.assertEqual(cert_filename, loader.ssl_ca_cert) + self.assertEqual(_TEST_TOKEN, loader.token) + + def _should_fail_load(self, config_loader, reason): + try: + config_loader.load_and_set() + self.fail("Should fail because %s" % reason) + except ConfigException: + # expected + pass + + def test_no_port(self): + loader = self.get_test_loader( + environ={SERVICE_HOST_ENV_NAME: _TEST_HOST}) + self._should_fail_load(loader, "no port specified") + + def test_empty_port(self): + loader = self.get_test_loader( + environ={SERVICE_HOST_ENV_NAME: _TEST_HOST, + SERVICE_PORT_ENV_NAME: ""}) + self._should_fail_load(loader, "empty port specified") + + def test_no_host(self): + loader = self.get_test_loader( + environ={SERVICE_PORT_ENV_NAME: _TEST_PORT}) + self._should_fail_load(loader, "no host specified") + + def test_empty_host(self): + loader = self.get_test_loader( + environ={SERVICE_HOST_ENV_NAME: "", + SERVICE_PORT_ENV_NAME: _TEST_PORT}) + self._should_fail_load(loader, "empty host specified") + + def test_no_cert_file(self): + loader = self.get_test_loader(cert_filename="not_exists_file_1123") + self._should_fail_load(loader, "cert file does not exists") + + def test_empty_cert_file(self): + loader = self.get_test_loader( + cert_filename=self._create_file_with_temp_content()) + self._should_fail_load(loader, "empty cert file provided") + + def test_no_token_file(self): + loader = self.get_test_loader(token_filename="not_exists_file_1123") + self._should_fail_load(loader, "token file does not exists") + + def test_empty_token_file(self): + loader = self.get_test_loader( + token_filename=self._create_file_with_temp_content()) + self._should_fail_load(loader, "empty token file provided") + + +if __name__ == '__main__': + unittest.main() diff --git a/config/kube_config.py b/config/kube_config.py new file mode 100644 index 00000000..b0ddeaa6 --- /dev/null +++ b/config/kube_config.py @@ -0,0 +1,321 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import base64 +import os +import tempfile + +import urllib3 +import yaml +from oauth2client.client import GoogleCredentials + +from kubernetes.client import ApiClient, ConfigurationObject, configuration + +from .config_exception import ConfigException + +KUBE_CONFIG_DEFAULT_LOCATION = '~/.kube/config' +_temp_files = {} + + +def _cleanup_temp_files(): + global _temp_files + for temp_file in _temp_files.values(): + try: + os.remove(temp_file) + except OSError: + pass + _temp_files = {} + + +def _create_temp_file_with_content(content): + if len(_temp_files) == 0: + atexit.register(_cleanup_temp_files) + # Because we may change context several times, try to remember files we + # created and reuse them at a small memory cost. + content_key = str(content) + if content_key in _temp_files: + return _temp_files[content_key] + _, name = tempfile.mkstemp() + _temp_files[content_key] = name + with open(name, 'wb') as fd: + fd.write(content.encode() if isinstance(content, str) else content) + return name + + +class FileOrData(object): + """Utility class to read content of obj[%data_key_name] or file's + content of obj[%file_key_name] and represent it as file or data. + Note that the data is preferred. The obj[%file_key_name] will be used iff + obj['%data_key_name'] is not set or empty. Assumption is file content is + raw data and data field is base64 string. The assumption can be changed + with base64_file_content flag. If set to False, the content of the file + will assumed to be base64 and read as is. The default True value will + result in base64 encode of the file content after read.""" + + def __init__(self, obj, file_key_name, data_key_name=None, + file_base_path="", base64_file_content=True): + if not data_key_name: + data_key_name = file_key_name + "-data" + self._file = None + self._data = None + self._base64_file_content = base64_file_content + if data_key_name in obj: + self._data = obj[data_key_name] + elif file_key_name in obj: + self._file = os.path.normpath( + os.path.join(file_base_path, obj[file_key_name])) + + def as_file(self): + """If obj[%data_key_name] exists, return name of a file with base64 + decoded obj[%data_key_name] content otherwise obj[%file_key_name].""" + use_data_if_no_file = not self._file and self._data + if use_data_if_no_file: + if self._base64_file_content: + self._file = _create_temp_file_with_content( + base64.decodestring(self._data.encode())) + else: + self._file = _create_temp_file_with_content(self._data) + if self._file and not os.path.isfile(self._file): + raise ConfigException("File does not exists: %s" % self._file) + return self._file + + def as_data(self): + """If obj[%data_key_name] exists, Return obj[%data_key_name] otherwise + base64 encoded string of obj[%file_key_name] file content.""" + use_file_if_no_data = not self._data and self._file + if use_file_if_no_data: + with open(self._file) as f: + if self._base64_file_content: + self._data = bytes.decode( + base64.encodestring(str.encode(f.read()))) + else: + self._data = f.read() + return self._data + + +class KubeConfigLoader(object): + + def __init__(self, config_dict, active_context=None, + get_google_credentials=None, + client_configuration=configuration, + config_base_path=""): + self._config = ConfigNode('kube-config', config_dict) + self._current_context = None + self._user = None + self._cluster = None + self.set_active_context(active_context) + self._config_base_path = config_base_path + if get_google_credentials: + self._get_google_credentials = get_google_credentials + else: + self._get_google_credentials = lambda: ( + GoogleCredentials.get_application_default() + .get_access_token().access_token) + self._client_configuration = client_configuration + + def set_active_context(self, context_name=None): + if context_name is None: + context_name = self._config['current-context'] + self._current_context = self._config['contexts'].get_with_name( + context_name) + if self._current_context['context'].safe_get('user'): + self._user = self._config['users'].get_with_name( + self._current_context['context']['user'])['user'] + else: + self._user = None + self._cluster = self._config['clusters'].get_with_name( + self._current_context['context']['cluster'])['cluster'] + + def _load_authentication(self): + """Read authentication from kube-config user section if exists. + + This function goes through various authentication methods in user + section of kube-config and stops if it finds a valid authentication + method. The order of authentication methods is: + + 1. GCP auth-provider + 2. token_data + 3. token field (point to a token file) + 4. username/password + """ + if not self._user: + return + if self._load_gcp_token(): + return + if self._load_user_token(): + return + self._load_user_pass_token() + + def _load_gcp_token(self): + if 'auth-provider' not in self._user: + return + if 'name' not in self._user['auth-provider']: + return + if self._user['auth-provider']['name'] != 'gcp': + return + # Ignore configs in auth-provider and rely on GoogleCredentials + # caching and refresh mechanism. + # TODO: support gcp command based token ("cmd-path" config). + self.token = "Bearer %s" % self._get_google_credentials() + return self.token + + def _load_user_token(self): + token = FileOrData( + self._user, 'tokenFile', 'token', + file_base_path=self._config_base_path, + base64_file_content=False).as_data() + if token: + self.token = "Bearer %s" % token + return True + + def _load_user_pass_token(self): + if 'username' in self._user and 'password' in self._user: + self.token = urllib3.util.make_headers( + basic_auth=(self._user['username'] + ':' + + self._user['password'])).get('authorization') + return True + + def _load_cluster_info(self): + if 'server' in self._cluster: + self.host = self._cluster['server'] + if self.host.startswith("https"): + self.ssl_ca_cert = FileOrData( + self._cluster, 'certificate-authority', + file_base_path=self._config_base_path).as_file() + self.cert_file = FileOrData( + self._user, 'client-certificate', + file_base_path=self._config_base_path).as_file() + self.key_file = FileOrData( + self._user, 'client-key', + file_base_path=self._config_base_path).as_file() + if 'insecure-skip-tls-verify' in self._cluster: + self.verify_ssl = not self._cluster['insecure-skip-tls-verify'] + + def _set_config(self): + if 'token' in self.__dict__: + self._client_configuration.api_key['authorization'] = self.token + # copy these keys directly from self to configuration object + keys = ['host', 'ssl_ca_cert', 'cert_file', 'key_file', 'verify_ssl'] + for key in keys: + if key in self.__dict__: + setattr(self._client_configuration, key, getattr(self, key)) + + def load_and_set(self): + self._load_authentication() + self._load_cluster_info() + self._set_config() + + def list_contexts(self): + return [context.value for context in self._config['contexts']] + + @property + def current_context(self): + return self._current_context.value + + +class ConfigNode(object): + """Remembers each config key's path and construct a relevant exception + message in case of missing keys. The assumption is all access keys are + present in a well-formed kube-config.""" + + def __init__(self, name, value): + self.name = name + self.value = value + + def __contains__(self, key): + return key in self.value + + def __len__(self): + return len(self.value) + + def safe_get(self, key): + if (isinstance(self.value, list) and isinstance(key, int) or + key in self.value): + return self.value[key] + + def __getitem__(self, key): + v = self.safe_get(key) + if not v: + raise ConfigException( + 'Invalid kube-config file. Expected key %s in %s' + % (key, self.name)) + if isinstance(v, dict) or isinstance(v, list): + return ConfigNode('%s/%s' % (self.name, key), v) + else: + return v + + def get_with_name(self, name): + if not isinstance(self.value, list): + raise ConfigException( + 'Invalid kube-config file. Expected %s to be a list' + % self.name) + for v in self.value: + if 'name' not in v: + raise ConfigException( + 'Invalid kube-config file. ' + 'Expected all values in %s list to have \'name\' key' + % self.name) + if v['name'] == name: + return ConfigNode('%s[name=%s]' % (self.name, name), v) + raise ConfigException( + 'Invalid kube-config file. ' + 'Expected object with name %s in %s list' % (name, self.name)) + + +def _get_kube_config_loader_for_yaml_file(filename, **kwargs): + with open(filename) as f: + return KubeConfigLoader( + config_dict=yaml.load(f), + config_base_path=os.path.abspath(os.path.dirname(filename)), + **kwargs) + + +def list_kube_config_contexts(config_file=None): + + if config_file is None: + config_file = os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION) + + loader = _get_kube_config_loader_for_yaml_file(config_file) + return loader.list_contexts(), loader.current_context + + +def load_kube_config(config_file=None, context=None, + client_configuration=configuration): + """Loads authentication and cluster information from kube-config file + and stores them in kubernetes.client.configuration. + + :param config_file: Name of the kube-config file. + :param context: set the active context. If is set to None, current_context + from config file will be used. + :param client_configuration: The kubernetes.client.ConfigurationObject to + set configs to. + """ + + if config_file is None: + config_file = os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION) + + _get_kube_config_loader_for_yaml_file( + config_file, active_context=context, + client_configuration=client_configuration).load_and_set() + + +def new_client_from_config(config_file=None, context=None): + """Loads configuration the same as load_kube_config but returns an ApiClient + to be used with any API object. This will allow the caller to concurrently + talk with multiple clusters.""" + client_config = ConfigurationObject() + load_kube_config(config_file=config_file, context=context, + client_configuration=client_config) + return ApiClient(config=client_config) diff --git a/config/kube_config_test.py b/config/kube_config_test.py new file mode 100644 index 00000000..6784b75b --- /dev/null +++ b/config/kube_config_test.py @@ -0,0 +1,620 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import os +import shutil +import tempfile +import unittest + +import yaml +from six import PY3 + +from .config_exception import ConfigException +from .kube_config import (ConfigNode, FileOrData, KubeConfigLoader, + _cleanup_temp_files, _create_temp_file_with_content, + list_kube_config_contexts, load_kube_config, + new_client_from_config) + +BEARER_TOKEN_FORMAT = "Bearer %s" + +NON_EXISTING_FILE = "zz_non_existing_file_472398324" + + +def _base64(string): + return base64.encodestring(string.encode()).decode() + + +TEST_FILE_KEY = "file" +TEST_DATA_KEY = "data" +TEST_FILENAME = "test-filename" + +TEST_DATA = "test-data" +TEST_DATA_BASE64 = _base64(TEST_DATA) + +TEST_ANOTHER_DATA = "another-test-data" +TEST_ANOTHER_DATA_BASE64 = _base64(TEST_ANOTHER_DATA) + +TEST_HOST = "test-host" +TEST_USERNAME = "me" +TEST_PASSWORD = "pass" +# token for me:pass +TEST_BASIC_TOKEN = "Basic bWU6cGFzcw==" + +TEST_SSL_HOST = "https://test-host" +TEST_CERTIFICATE_AUTH = "cert-auth" +TEST_CERTIFICATE_AUTH_BASE64 = _base64(TEST_CERTIFICATE_AUTH) +TEST_CLIENT_KEY = "client-key" +TEST_CLIENT_KEY_BASE64 = _base64(TEST_CLIENT_KEY) +TEST_CLIENT_CERT = "client-cert" +TEST_CLIENT_CERT_BASE64 = _base64(TEST_CLIENT_CERT) + + +class BaseTestCase(unittest.TestCase): + + def setUp(self): + self._temp_files = [] + + def tearDown(self): + for f in self._temp_files: + os.remove(f) + + def _create_temp_file(self, content=""): + handler, name = tempfile.mkstemp() + self._temp_files.append(name) + os.write(handler, str.encode(content)) + os.close(handler) + return name + + def expect_exception(self, func, message_part): + with self.assertRaises(ConfigException) as context: + func() + self.assertIn(message_part, str(context.exception)) + + +class TestFileOrData(BaseTestCase): + + @staticmethod + def get_file_content(filename): + with open(filename) as f: + return f.read() + + def test_file_given_file(self): + temp_filename = _create_temp_file_with_content(TEST_DATA) + obj = {TEST_FILE_KEY: temp_filename} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY) + self.assertEqual(TEST_DATA, self.get_file_content(t.as_file())) + + def test_file_given_non_existing_file(self): + temp_filename = NON_EXISTING_FILE + obj = {TEST_FILE_KEY: temp_filename} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY) + self.expect_exception(t.as_file, "does not exists") + + def test_file_given_data(self): + obj = {TEST_DATA_KEY: TEST_DATA_BASE64} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + data_key_name=TEST_DATA_KEY) + self.assertEqual(TEST_DATA, self.get_file_content(t.as_file())) + + def test_file_given_data_no_base64(self): + obj = {TEST_DATA_KEY: TEST_DATA} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + data_key_name=TEST_DATA_KEY, base64_file_content=False) + self.assertEqual(TEST_DATA, self.get_file_content(t.as_file())) + + def test_data_given_data(self): + obj = {TEST_DATA_KEY: TEST_DATA_BASE64} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + data_key_name=TEST_DATA_KEY) + self.assertEqual(TEST_DATA_BASE64, t.as_data()) + + def test_data_given_file(self): + obj = { + TEST_FILE_KEY: self._create_temp_file(content=TEST_DATA)} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY) + self.assertEqual(TEST_DATA_BASE64, t.as_data()) + + def test_data_given_file_no_base64(self): + obj = { + TEST_FILE_KEY: self._create_temp_file(content=TEST_DATA)} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + base64_file_content=False) + self.assertEqual(TEST_DATA, t.as_data()) + + def test_data_given_file_and_data(self): + obj = { + TEST_DATA_KEY: TEST_DATA_BASE64, + TEST_FILE_KEY: self._create_temp_file( + content=TEST_ANOTHER_DATA)} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + data_key_name=TEST_DATA_KEY) + self.assertEqual(TEST_DATA_BASE64, t.as_data()) + + def test_file_given_file_and_data(self): + obj = { + TEST_DATA_KEY: TEST_DATA_BASE64, + TEST_FILE_KEY: self._create_temp_file( + content=TEST_ANOTHER_DATA)} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + data_key_name=TEST_DATA_KEY) + self.assertEqual(TEST_DATA, self.get_file_content(t.as_file())) + + def test_file_with_custom_dirname(self): + tempfile = self._create_temp_file(content=TEST_DATA) + tempfile_dir = os.path.dirname(tempfile) + tempfile_basename = os.path.basename(tempfile) + obj = {TEST_FILE_KEY: tempfile_basename} + t = FileOrData(obj=obj, file_key_name=TEST_FILE_KEY, + file_base_path=tempfile_dir) + self.assertEqual(TEST_DATA, self.get_file_content(t.as_file())) + + def test_create_temp_file_with_content(self): + self.assertEqual(TEST_DATA, + self.get_file_content( + _create_temp_file_with_content(TEST_DATA))) + _cleanup_temp_files() + + +class TestConfigNode(BaseTestCase): + + test_obj = {"key1": "test", "key2": ["a", "b", "c"], + "key3": {"inner_key": "inner_value"}, + "with_names": [{"name": "test_name", "value": "test_value"}, + {"name": "test_name2", + "value": {"key1", "test"}}, + {"name": "test_name3", "value": [1, 2, 3]}]} + + def setUp(self): + super(TestConfigNode, self).setUp() + self.node = ConfigNode("test_obj", self.test_obj) + + def test_normal_map_array_operations(self): + self.assertEqual("test", self.node['key1']) + self.assertEqual(4, len(self.node)) + + self.assertEqual("test_obj/key2", self.node['key2'].name) + self.assertEqual(["a", "b", "c"], self.node['key2'].value) + self.assertEqual("b", self.node['key2'][1]) + self.assertEqual(3, len(self.node['key2'])) + + self.assertEqual("test_obj/key3", self.node['key3'].name) + self.assertEqual({"inner_key": "inner_value"}, self.node['key3'].value) + self.assertEqual("inner_value", self.node['key3']["inner_key"]) + self.assertEqual(1, len(self.node['key3'])) + + def test_get_with_name(self): + node = self.node["with_names"] + self.assertEqual( + "test_value", + node.get_with_name("test_name")["value"]) + self.assertTrue( + isinstance(node.get_with_name("test_name2"), ConfigNode)) + self.assertTrue( + isinstance(node.get_with_name("test_name3"), ConfigNode)) + self.assertEqual("test_obj/with_names[name=test_name2]", + node.get_with_name("test_name2").name) + self.assertEqual("test_obj/with_names[name=test_name3]", + node.get_with_name("test_name3").name) + + def test_key_does_not_exists(self): + self.expect_exception(lambda: self.node['not-exists-key'], + "Expected key not-exists-key in test_obj") + self.expect_exception(lambda: self.node['key3']['not-exists-key'], + "Expected key not-exists-key in test_obj/key3") + + def test_get_with_name_on_invalid_object(self): + self.expect_exception( + lambda: self.node['key2'].get_with_name('no-name'), + "Expected all values in test_obj/key2 list to have \'name\' key") + + def test_get_with_name_on_non_list_object(self): + self.expect_exception( + lambda: self.node['key3'].get_with_name('no-name'), + "Expected test_obj/key3 to be a list") + + def test_get_with_name_on_name_does_not_exists(self): + self.expect_exception( + lambda: self.node['with_names'].get_with_name('no-name'), + "Expected object with name no-name in test_obj/with_names list") + + +class FakeConfig: + + FILE_KEYS = ["ssl_ca_cert", "key_file", "cert_file"] + + def __init__(self, token=None, **kwargs): + self.api_key = {} + if token: + self.api_key['authorization'] = token + + self.__dict__.update(kwargs) + + def __eq__(self, other): + if len(self.__dict__) != len(other.__dict__): + return + for k, v in self.__dict__.items(): + if k not in other.__dict__: + return + if k in self.FILE_KEYS: + if v and other.__dict__[k]: + try: + with open(v) as f1, open(other.__dict__[k]) as f2: + if f1.read() != f2.read(): + return + except IOError: + # fall back to only compare filenames in case we are + # testing the passing of filenames to the config + if other.__dict__[k] != v: + return + else: + if other.__dict__[k] != v: + return + else: + if other.__dict__[k] != v: + return + return True + + def __repr__(self): + rep = "\n" + for k, v in self.__dict__.items(): + val = v + if k in self.FILE_KEYS: + try: + with open(v) as f: + val = "FILE: %s" % str.decode(f.read()) + except IOError as e: + val = "ERROR: %s" % str(e) + rep += "\t%s: %s\n" % (k, val) + return "Config(%s\n)" % rep + + +class TestKubeConfigLoader(BaseTestCase): + TEST_KUBE_CONFIG = { + "current-context": "no_user", + "contexts": [ + { + "name": "no_user", + "context": { + "cluster": "default" + } + }, + { + "name": "simple_token", + "context": { + "cluster": "default", + "user": "simple_token" + } + }, + { + "name": "gcp", + "context": { + "cluster": "default", + "user": "gcp" + } + }, + { + "name": "user_pass", + "context": { + "cluster": "default", + "user": "user_pass" + } + }, + { + "name": "ssl", + "context": { + "cluster": "ssl", + "user": "ssl" + } + }, + { + "name": "no_ssl_verification", + "context": { + "cluster": "no_ssl_verification", + "user": "ssl" + } + }, + { + "name": "ssl-no_file", + "context": { + "cluster": "ssl-no_file", + "user": "ssl-no_file" + } + }, + { + "name": "ssl-local-file", + "context": { + "cluster": "ssl-local-file", + "user": "ssl-local-file" + } + }, + ], + "clusters": [ + { + "name": "default", + "cluster": { + "server": TEST_HOST + } + }, + { + "name": "ssl-no_file", + "cluster": { + "server": TEST_SSL_HOST, + "certificate-authority": TEST_CERTIFICATE_AUTH, + } + }, + { + "name": "ssl-local-file", + "cluster": { + "server": TEST_SSL_HOST, + "certificate-authority": "cert_test", + } + }, + { + "name": "ssl", + "cluster": { + "server": TEST_SSL_HOST, + "certificate-authority-data": TEST_CERTIFICATE_AUTH_BASE64, + } + }, + { + "name": "no_ssl_verification", + "cluster": { + "server": TEST_SSL_HOST, + "insecure-skip-tls-verify": "true", + } + }, + ], + "users": [ + { + "name": "simple_token", + "user": { + "token": TEST_DATA_BASE64, + "username": TEST_USERNAME, # should be ignored + "password": TEST_PASSWORD, # should be ignored + } + }, + { + "name": "gcp", + "user": { + "auth-provider": { + "name": "gcp", + "access_token": "not_used", + }, + "token": TEST_DATA_BASE64, # should be ignored + "username": TEST_USERNAME, # should be ignored + "password": TEST_PASSWORD, # should be ignored + } + }, + { + "name": "user_pass", + "user": { + "username": TEST_USERNAME, # should be ignored + "password": TEST_PASSWORD, # should be ignored + } + }, + { + "name": "ssl-no_file", + "user": { + "token": TEST_DATA_BASE64, + "client-certificate": TEST_CLIENT_CERT, + "client-key": TEST_CLIENT_KEY, + } + }, + { + "name": "ssl-local-file", + "user": { + "tokenFile": "token_file", + "client-certificate": "client_cert", + "client-key": "client_key", + } + }, + { + "name": "ssl", + "user": { + "token": TEST_DATA_BASE64, + "client-certificate-data": TEST_CLIENT_CERT_BASE64, + "client-key-data": TEST_CLIENT_KEY_BASE64, + } + }, + ] + } + + def test_no_user_context(self): + expected = FakeConfig(host=TEST_HOST) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="no_user", + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + + def test_simple_token(self): + expected = FakeConfig(host=TEST_HOST, + token=BEARER_TOKEN_FORMAT % TEST_DATA_BASE64) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="simple_token", + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + + def test_load_user_token(self): + loader = KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="simple_token") + self.assertTrue(loader._load_user_token()) + self.assertEqual(BEARER_TOKEN_FORMAT % TEST_DATA_BASE64, loader.token) + + def test_gcp(self): + expected = FakeConfig( + host=TEST_HOST, + token=BEARER_TOKEN_FORMAT % TEST_ANOTHER_DATA_BASE64) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="gcp", + client_configuration=actual, + get_google_credentials=lambda: TEST_ANOTHER_DATA_BASE64) \ + .load_and_set() + self.assertEqual(expected, actual) + + def test_load_gcp_token(self): + loader = KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="gcp", + get_google_credentials=lambda: TEST_ANOTHER_DATA_BASE64) + self.assertTrue(loader._load_gcp_token()) + self.assertEqual(BEARER_TOKEN_FORMAT % TEST_ANOTHER_DATA_BASE64, + loader.token) + + def test_user_pass(self): + expected = FakeConfig(host=TEST_HOST, token=TEST_BASIC_TOKEN) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="user_pass", + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + + def test_load_user_pass_token(self): + loader = KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="user_pass") + self.assertTrue(loader._load_user_pass_token()) + self.assertEqual(TEST_BASIC_TOKEN, loader.token) + + def test_ssl_no_cert_files(self): + actual = FakeConfig() + loader = KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="ssl-no_file", + client_configuration=actual) + self.expect_exception(loader.load_and_set, "does not exists") + + def test_ssl(self): + expected = FakeConfig( + host=TEST_SSL_HOST, + token=BEARER_TOKEN_FORMAT % TEST_DATA_BASE64, + cert_file=self._create_temp_file(TEST_CLIENT_CERT), + key_file=self._create_temp_file(TEST_CLIENT_KEY), + ssl_ca_cert=self._create_temp_file(TEST_CERTIFICATE_AUTH) + ) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="ssl", + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + + def test_ssl_no_verification(self): + expected = FakeConfig( + host=TEST_SSL_HOST, + token=BEARER_TOKEN_FORMAT % TEST_DATA_BASE64, + cert_file=self._create_temp_file(TEST_CLIENT_CERT), + key_file=self._create_temp_file(TEST_CLIENT_KEY), + verify_ssl=False, + ssl_ca_cert=None, + ) + actual = FakeConfig() + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="no_ssl_verification", + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + + def test_list_contexts(self): + loader = KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="no_user") + actual_contexts = loader.list_contexts() + expected_contexts = ConfigNode("", self.TEST_KUBE_CONFIG)['contexts'] + for actual in actual_contexts: + expected = expected_contexts.get_with_name(actual['name']) + self.assertEqual(expected.value, actual) + + def test_current_context(self): + loader = KubeConfigLoader(config_dict=self.TEST_KUBE_CONFIG) + expected_contexts = ConfigNode("", self.TEST_KUBE_CONFIG)['contexts'] + self.assertEqual(expected_contexts.get_with_name("no_user").value, + loader.current_context) + + def test_set_active_context(self): + loader = KubeConfigLoader(config_dict=self.TEST_KUBE_CONFIG) + loader.set_active_context("ssl") + expected_contexts = ConfigNode("", self.TEST_KUBE_CONFIG)['contexts'] + self.assertEqual(expected_contexts.get_with_name("ssl").value, + loader.current_context) + + def test_ssl_with_relative_ssl_files(self): + expected = FakeConfig( + host=TEST_SSL_HOST, + token=BEARER_TOKEN_FORMAT % TEST_DATA_BASE64, + cert_file=self._create_temp_file(TEST_CLIENT_CERT), + key_file=self._create_temp_file(TEST_CLIENT_KEY), + ssl_ca_cert=self._create_temp_file(TEST_CERTIFICATE_AUTH) + ) + try: + temp_dir = tempfile.mkdtemp() + actual = FakeConfig() + with open(os.path.join(temp_dir, "cert_test"), "wb") as fd: + fd.write(TEST_CERTIFICATE_AUTH.encode()) + with open(os.path.join(temp_dir, "client_cert"), "wb") as fd: + fd.write(TEST_CLIENT_CERT.encode()) + with open(os.path.join(temp_dir, "client_key"), "wb") as fd: + fd.write(TEST_CLIENT_KEY.encode()) + with open(os.path.join(temp_dir, "token_file"), "wb") as fd: + fd.write(TEST_DATA_BASE64.encode()) + KubeConfigLoader( + config_dict=self.TEST_KUBE_CONFIG, + active_context="ssl-local-file", + config_base_path=temp_dir, + client_configuration=actual).load_and_set() + self.assertEqual(expected, actual) + finally: + shutil.rmtree(temp_dir) + + def test_load_kube_config(self): + expected = FakeConfig(host=TEST_HOST, + token=BEARER_TOKEN_FORMAT % TEST_DATA_BASE64) + config_file = self._create_temp_file(yaml.dump(self.TEST_KUBE_CONFIG)) + actual = FakeConfig() + load_kube_config(config_file=config_file, context="simple_token", + client_configuration=actual) + self.assertEqual(expected, actual) + + def test_list_kube_config_contexts(self): + config_file = self._create_temp_file(yaml.dump(self.TEST_KUBE_CONFIG)) + contexts, active_context = list_kube_config_contexts( + config_file=config_file) + self.assertDictEqual(self.TEST_KUBE_CONFIG['contexts'][0], + active_context) + if PY3: + self.assertCountEqual(self.TEST_KUBE_CONFIG['contexts'], + contexts) + else: + self.assertItemsEqual(self.TEST_KUBE_CONFIG['contexts'], + contexts) + + def test_new_client_from_config(self): + config_file = self._create_temp_file(yaml.dump(self.TEST_KUBE_CONFIG)) + client = new_client_from_config( + config_file=config_file, context="simple_token") + self.assertEqual(TEST_HOST, client.config.host) + self.assertEqual(BEARER_TOKEN_FORMAT % TEST_DATA_BASE64, + client.config.api_key['authorization']) + + +if __name__ == '__main__': + unittest.main() diff --git a/configuration.py b/configuration.py new file mode 100644 index 00000000..bf0fd733 --- /dev/null +++ b/configuration.py @@ -0,0 +1,237 @@ +# coding: utf-8 + +""" + Kubernetes + + No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) + + OpenAPI spec version: v1.5.0-snapshot + + Generated by: https://github.com/swagger-api/swagger-codegen.git + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from __future__ import absolute_import + +import urllib3 + +import sys +import logging + +from six import iteritems +from six.moves import http_client as httplib + + +class ConfigurationObject(object): + """ + NOTE: This class is auto generated by the swagger code generator program. + Ref: https://github.com/swagger-api/swagger-codegen + Do not edit the class manually. + """ + + def __init__(self): + """ + Constructor + """ + # Default Base url + self.host = "https://localhost" + # Default api client + self.api_client = None + # Temp file folder for downloading files + self.temp_folder_path = None + + # Authentication Settings + # dict to store API key(s) + self.api_key = {} + # dict to store API prefix (e.g. Bearer) + self.api_key_prefix = {} + # Username for HTTP basic authentication + self.username = "" + # Password for HTTP basic authentication + self.password = "" + + # Logging Settings + self.logger = {} + self.logger["package_logger"] = logging.getLogger("client") + self.logger["urllib3_logger"] = logging.getLogger("urllib3") + # Log format + self.logger_format = '%(asctime)s %(levelname)s %(message)s' + # Log stream handler + self.logger_stream_handler = None + # Log file handler + self.logger_file_handler = None + # Debug file location + self.logger_file = None + # Debug switch + self.debug = False + + # SSL/TLS verification + # Set this to false to skip verifying SSL certificate when calling API from https server. + self.verify_ssl = True + # Set this to customize the certificate file to verify the peer. + self.ssl_ca_cert = None + # client certificate file + self.cert_file = None + # client key file + self.key_file = None + # check host name + # Set this to True/False to enable/disable SSL hostname verification. + self.assert_hostname = None + + @property + def logger_file(self): + """ + Gets the logger_file. + """ + return self.__logger_file + + @logger_file.setter + def logger_file(self, value): + """ + Sets the logger_file. + + If the logger_file is None, then add stream handler and remove file handler. + Otherwise, add file handler and remove stream handler. + + :param value: The logger_file path. + :type: str + """ + self.__logger_file = value + if self.__logger_file: + # If set logging file, + # then add file handler and remove stream handler. + self.logger_file_handler = logging.FileHandler(self.__logger_file) + self.logger_file_handler.setFormatter(self.logger_formatter) + for _, logger in iteritems(self.logger): + logger.addHandler(self.logger_file_handler) + if self.logger_stream_handler: + logger.removeHandler(self.logger_stream_handler) + else: + # If not set logging file, + # then add stream handler and remove file handler. + self.logger_stream_handler = logging.StreamHandler() + self.logger_stream_handler.setFormatter(self.logger_formatter) + for _, logger in iteritems(self.logger): + logger.addHandler(self.logger_stream_handler) + if self.logger_file_handler: + logger.removeHandler(self.logger_file_handler) + + @property + def debug(self): + """ + Gets the debug status. + """ + return self.__debug + + @debug.setter + def debug(self, value): + """ + Sets the debug status. + + :param value: The debug status, True or False. + :type: bool + """ + self.__debug = value + if self.__debug: + # if debug status is True, turn on debug logging + for _, logger in iteritems(self.logger): + logger.setLevel(logging.DEBUG) + # turn on httplib debug + httplib.HTTPConnection.debuglevel = 1 + else: + # if debug status is False, turn off debug logging, + # setting log level to default `logging.WARNING` + for _, logger in iteritems(self.logger): + logger.setLevel(logging.WARNING) + # turn off httplib debug + httplib.HTTPConnection.debuglevel = 0 + + @property + def logger_format(self): + """ + Gets the logger_format. + """ + return self.__logger_format + + @logger_format.setter + def logger_format(self, value): + """ + Sets the logger_format. + + The logger_formatter will be updated when sets logger_format. + + :param value: The format string. + :type: str + """ + self.__logger_format = value + self.logger_formatter = logging.Formatter(self.__logger_format) + + def get_api_key_with_prefix(self, identifier): + """ + Gets API key (with prefix if set). + + :param identifier: The identifier of apiKey. + :return: The token for api key authentication. + """ + if self.api_key.get(identifier) and self.api_key_prefix.get(identifier): + return self.api_key_prefix[identifier] + ' ' + self.api_key[identifier] + elif self.api_key.get(identifier): + return self.api_key[identifier] + + def get_basic_auth_token(self): + """ + Gets HTTP basic authentication header (string). + + :return: The token for basic HTTP authentication. + """ + return urllib3.util.make_headers(basic_auth=self.username + ':' + self.password)\ + .get('authorization') + + def auth_settings(self): + """ + Gets Auth Settings dict for api client. + + :return: The Auth Settings information dict. + """ + return { + 'BearerToken': + { + 'type': 'api_key', + 'in': 'header', + 'key': 'authorization', + 'value': self.get_api_key_with_prefix('authorization') + }, + + } + + def to_debug_report(self): + """ + Gets the essential information for debugging. + + :return: The report for debugging. + """ + return "Python SDK Debug Report:\n"\ + "OS: {env}\n"\ + "Python Version: {pyversion}\n"\ + "Version of the API: v1.5.0-snapshot\n"\ + "SDK Package Version: 1.0.0-snapshot".\ + format(env=sys.platform, pyversion=sys.version) + + +configuration = ConfigurationObject() + + +def Configuration(): + """Simulate a singelton Configuration object for backward compatibility.""" + return configuration diff --git a/rest.py b/rest.py new file mode 100644 index 00000000..8b3a5dab --- /dev/null +++ b/rest.py @@ -0,0 +1,324 @@ +# coding: utf-8 + +""" + Kubernetes + + No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) + + OpenAPI spec version: v1.5.0-snapshot + + Generated by: https://github.com/swagger-api/swagger-codegen.git + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from __future__ import absolute_import + +import io +import json +import ssl +import certifi +import logging +import re + +# python 2 and python 3 compatibility library +from six import PY3 +from six.moves.urllib.parse import urlencode + +from .configuration import configuration + +try: + import urllib3 +except ImportError: + raise ImportError('Swagger python client requires urllib3.') + + +logger = logging.getLogger(__name__) + + +class RESTResponse(io.IOBase): + + def __init__(self, resp): + self.urllib3_response = resp + self.status = resp.status + self.reason = resp.reason + self.data = resp.data + + def getheaders(self): + """ + Returns a dictionary of the response headers. + """ + return self.urllib3_response.getheaders() + + def getheader(self, name, default=None): + """ + Returns a given response header. + """ + return self.urllib3_response.getheader(name, default) + + +class RESTClientObject(object): + + def __init__(self, pools_size=4, config=configuration): + # urllib3.PoolManager will pass all kw parameters to connectionpool + # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 + # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 + # ca_certs vs cert_file vs key_file + # http://stackoverflow.com/a/23957365/2985775 + + # cert_reqs + if config.verify_ssl: + cert_reqs = ssl.CERT_REQUIRED + else: + cert_reqs = ssl.CERT_NONE + + # ca_certs + if config.ssl_ca_cert: + ca_certs = config.ssl_ca_cert + else: + # if not set certificate file, use Mozilla's root certificates. + ca_certs = certifi.where() + + # cert_file + cert_file = config.cert_file + + # key file + key_file = config.key_file + + kwargs = { + 'num_pools': pools_size, + 'cert_reqs': cert_reqs, + 'ca_certs': ca_certs, + 'cert_file': cert_file, + 'key_file': key_file, + } + + if config.assert_hostname is not None: + kwargs['assert_hostname'] = config.assert_hostname + + # https pool manager + self.pool_manager = urllib3.PoolManager( + **kwargs + ) + + def request(self, method, url, query_params=None, headers=None, + body=None, post_params=None, _preload_content=True, _request_timeout=None): + """ + :param method: http request method + :param url: http request url + :param query_params: query parameters in the url + :param headers: http request headers + :param body: request json body, for `application/json` + :param post_params: request post parameters, + `application/x-www-form-urlencoded` + and `multipart/form-data` + :param _preload_content: if False, the urllib3.HTTPResponse object will be returned without + reading/decoding response data. Default is True. + :param _request_timeout: timeout setting for this request. If one number provided, it will be total request + timeout. It can also be a pair (tuple) of (connection, read) timeouts. + """ + method = method.upper() + assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'] + + if post_params and body: + raise ValueError( + "body parameter cannot be used with post_params parameter." + ) + + post_params = post_params or {} + headers = headers or {} + + timeout = None + if _request_timeout: + if isinstance(_request_timeout, (int, ) if PY3 else (int, long)): + timeout = urllib3.Timeout(total=_request_timeout) + elif isinstance(_request_timeout, tuple) and len(_request_timeout) == 2: + timeout = urllib3.Timeout(connect=_request_timeout[0], read=_request_timeout[1]) + + if 'Content-Type' not in headers: + headers['Content-Type'] = 'application/json' + + try: + # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE` + if method in ['POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE']: + if query_params: + url += '?' + urlencode(query_params) + if headers['Content-Type'] == 'application/json-patch+json': + if not isinstance(body, list): + headers['Content-Type'] = 'application/strategic-merge-patch+json' + request_body = None + if body: + request_body = json.dumps(body) + r = self.pool_manager.request(method, url, + body=request_body, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + elif re.search('json', headers['Content-Type'], re.IGNORECASE): + request_body = None + if body: + request_body = json.dumps(body) + r = self.pool_manager.request(method, url, + body=request_body, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + r = self.pool_manager.request(method, url, + fields=post_params, + encode_multipart=False, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + elif headers['Content-Type'] == 'multipart/form-data': + # must del headers['Content-Type'], or the correct Content-Type + # which generated by urllib3 will be overwritten. + del headers['Content-Type'] + r = self.pool_manager.request(method, url, + fields=post_params, + encode_multipart=True, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + # Pass a `string` parameter directly in the body to support + # other content types than Json when `body` argument is provided + # in serialized form + elif isinstance(body, str): + request_body = body + r = self.pool_manager.request(method, url, + body=request_body, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + else: + # Cannot generate the request from given parameters + msg = """Cannot prepare a request message for provided arguments. + Please check that your arguments match declared content type.""" + raise ApiException(status=0, reason=msg) + # For `GET`, `HEAD` + else: + r = self.pool_manager.request(method, url, + fields=query_params, + preload_content=_preload_content, + timeout=timeout, + headers=headers) + except urllib3.exceptions.SSLError as e: + msg = "{0}\n{1}".format(type(e).__name__, str(e)) + raise ApiException(status=0, reason=msg) + + if _preload_content: + r = RESTResponse(r) + + # In the python 3, the response.data is bytes. + # we need to decode it to string. + if PY3: + r.data = r.data.decode('utf8') + + # log response body + logger.debug("response body: %s", r.data) + + if r.status not in range(200, 206): + raise ApiException(http_resp=r) + + return r + + def GET(self, url, headers=None, query_params=None, _preload_content=True, _request_timeout=None): + return self.request("GET", url, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params) + + def HEAD(self, url, headers=None, query_params=None, _preload_content=True, _request_timeout=None): + return self.request("HEAD", url, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params) + + def OPTIONS(self, url, headers=None, query_params=None, post_params=None, body=None, _preload_content=True, + _request_timeout=None): + return self.request("OPTIONS", url, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + + def DELETE(self, url, headers=None, query_params=None, body=None, _preload_content=True, _request_timeout=None): + return self.request("DELETE", url, + headers=headers, + query_params=query_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + + def POST(self, url, headers=None, query_params=None, post_params=None, body=None, _preload_content=True, + _request_timeout=None): + return self.request("POST", url, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + + def PUT(self, url, headers=None, query_params=None, post_params=None, body=None, _preload_content=True, + _request_timeout=None): + return self.request("PUT", url, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + + def PATCH(self, url, headers=None, query_params=None, post_params=None, body=None, _preload_content=True, + _request_timeout=None): + return self.request("PATCH", url, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body) + + +class ApiException(Exception): + + def __init__(self, status=None, reason=None, http_resp=None): + if http_resp: + self.status = http_resp.status + self.reason = http_resp.reason + self.body = http_resp.data + self.headers = http_resp.getheaders() + else: + self.status = status + self.reason = reason + self.body = None + self.headers = None + + def __str__(self): + """ + Custom error messages for exception + """ + error_message = "({0})\n"\ + "Reason: {1}\n".format(self.status, self.reason) + if self.headers: + error_message += "HTTP response headers: {0}\n".format(self.headers) + + if self.body: + error_message += "HTTP response body: {0}\n".format(self.body) + + return error_message diff --git a/watch/__init__.py b/watch/__init__.py new file mode 100644 index 00000000..ca9ac069 --- /dev/null +++ b/watch/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .watch import Watch diff --git a/watch/watch b/watch/watch new file mode 120000 index 00000000..1655a60f --- /dev/null +++ b/watch/watch @@ -0,0 +1 @@ +watch \ No newline at end of file diff --git a/watch/watch.py b/watch/watch.py new file mode 100644 index 00000000..9dd7af79 --- /dev/null +++ b/watch/watch.py @@ -0,0 +1,123 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import pydoc + +from kubernetes import client + +PYDOC_RETURN_LABEL = ":return:" + +# Removing this suffix from return type name should give us event's object +# type. e.g., if list_namespaces() returns "NamespaceList" type, +# then list_namespaces(watch=true) returns a stream of events with objects +# of type "Namespace". In case this assumption is not true, user should +# provide return_type to Watch class's __init__. +TYPE_LIST_SUFFIX = "List" + + +class SimpleNamespace: + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +def _find_return_type(func): + for line in pydoc.getdoc(func).splitlines(): + if line.startswith(PYDOC_RETURN_LABEL): + return line[len(PYDOC_RETURN_LABEL):].strip() + return "" + + +def iter_resp_lines(resp): + prev = "" + for seg in resp.read_chunked(decode_content=False): + if isinstance(seg, bytes): + seg = seg.decode('utf8') + seg = prev + seg + lines = seg.split("\n") + if not seg.endswith("\n"): + prev = lines[-1] + lines = lines[:-1] + else: + prev = "" + for line in lines: + if line: + yield line + + +class Watch(object): + + def __init__(self, return_type=None): + self._raw_return_type = return_type + self._stop = False + self._api_client = client.ApiClient() + + def stop(self): + self._stop = True + + def get_return_type(self, func): + if self._raw_return_type: + return self._raw_return_type + return_type = _find_return_type(func) + if return_type.endswith(TYPE_LIST_SUFFIX): + return return_type[:-len(TYPE_LIST_SUFFIX)] + return return_type + + def unmarshal_event(self, data, return_type): + js = json.loads(data) + js['raw_object'] = js['object'] + if return_type: + obj = SimpleNamespace(data=json.dumps(js['raw_object'])) + js['object'] = self._api_client.deserialize(obj, return_type) + return js + + def stream(self, func, *args, **kwargs): + """Watch an API resource and stream the result back via a generator. + + :param func: The API function pointer. Any parameter to the function + can be passed after this parameter. + + :return: Event object with these keys: + 'type': The type of event such as "ADDED", "DELETED", etc. + 'raw_object': a dict representing the watched object. + 'object': A model representation of raw_object. The name of + model will be determined based on + the func's doc string. If it cannot be determined, + 'object' value will be the same as 'raw_object'. + + Example: + v1 = kubernetes.client.CoreV1Api() + watch = kubernetes.watch.Watch() + for e in watch.stream(v1.list_namespace, resource_version=1127): + type = e['type'] + object = e['object'] # object is one of type return_type + raw_object = e['raw_object'] # raw_object is a dict + ... + if should_stop: + watch.stop() + """ + + return_type = self.get_return_type(func) + kwargs['watch'] = True + kwargs['_preload_content'] = False + resp = func(*args, **kwargs) + try: + for line in iter_resp_lines(resp): + yield self.unmarshal_event(line, return_type) + if self._stop: + break + finally: + resp.close() + resp.release_conn() diff --git a/watch/watch_test.py b/watch/watch_test.py new file mode 100644 index 00000000..0f441bef --- /dev/null +++ b/watch/watch_test.py @@ -0,0 +1,102 @@ +# Copyright 2016 The Kubernetes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from mock import Mock + +from .watch import Watch + + +class WatchTests(unittest.TestCase): + + def test_watch_with_decode(self): + fake_resp = Mock() + fake_resp.close = Mock() + fake_resp.release_conn = Mock() + fake_resp.read_chunked = Mock( + return_value=[ + '{"type": "ADDED", "object": {"metadata": {"name": "test1"}' + ',"spec": {}, "status": {}}}\n', + '{"type": "ADDED", "object": {"metadata": {"name": "test2"}' + ',"spec": {}, "sta', + 'tus": {}}}\n' + '{"type": "ADDED", "object": {"metadata": {"name": "test3"},' + '"spec": {}, "status": {}}}\n', + 'should_not_happened\n']) + + fake_api = Mock() + fake_api.get_namespaces = Mock(return_value=fake_resp) + fake_api.get_namespaces.__doc__ = ':return: V1NamespaceList' + + w = Watch() + count = 1 + for e in w.stream(fake_api.get_namespaces): + self.assertEqual("ADDED", e['type']) + # make sure decoder worked and we got a model with the right name + self.assertEqual("test%d" % count, e['object'].metadata.name) + count += 1 + # make sure we can stop the watch and the last event with won't be + # returned + if count == 4: + w.stop() + + fake_api.get_namespaces.assert_called_once_with( + _preload_content=False, watch=True) + fake_resp.read_chunked.assert_called_once_with(decode_content=False) + fake_resp.close.assert_called_once() + fake_resp.release_conn.assert_called_once() + + def test_unmarshal_with_float_object(self): + w = Watch() + event = w.unmarshal_event('{"type": "ADDED", "object": 1}', 'float') + self.assertEqual("ADDED", event['type']) + self.assertEqual(1.0, event['object']) + self.assertTrue(isinstance(event['object'], float)) + self.assertEqual(1, event['raw_object']) + + def test_unmarshal_with_no_return_type(self): + w = Watch() + event = w.unmarshal_event('{"type": "ADDED", "object": ["test1"]}', + None) + self.assertEqual("ADDED", event['type']) + self.assertEqual(["test1"], event['object']) + self.assertEqual(["test1"], event['raw_object']) + + def test_watch_with_exception(self): + fake_resp = Mock() + fake_resp.close = Mock() + fake_resp.release_conn = Mock() + fake_resp.read_chunked = Mock(side_effect=KeyError('expected')) + + fake_api = Mock() + fake_api.get_thing = Mock(return_value=fake_resp) + + w = Watch() + try: + for _ in w.stream(fake_api.get_thing): + self.fail(self, "Should fail on exception.") + except KeyError: + pass + # expected + + fake_api.get_thing.assert_called_once_with( + _preload_content=False, watch=True) + fake_resp.read_chunked.assert_called_once_with(decode_content=False) + fake_resp.close.assert_called_once() + fake_resp.release_conn.assert_called_once() + + +if __name__ == '__main__': + unittest.main()