diff --git a/adafruit_jwt.py b/adafruit_jwt.py index 775e152..d8124d4 100644 --- a/adafruit_jwt.py +++ b/adafruit_jwt.py @@ -77,8 +77,8 @@ def validate(jwt): # Attempt to decode JOSE header try: jose_header = STRING_TOOLS.urlsafe_b64decode(jwt.split(".")[0]) - except UnicodeError: - raise UnicodeError("Unable to decode JOSE header.") + except UnicodeError as unicode_error: + raise UnicodeError("Unable to decode JOSE header.") from unicode_error # Check for typ and alg in decoded JOSE header if "typ" not in jose_header: raise TypeError("JOSE Header does not contain required type key.") @@ -87,17 +87,19 @@ def validate(jwt): # Attempt to decode claim set try: claims = json.loads(STRING_TOOLS.urlsafe_b64decode(jwt.split(".")[1])) - except UnicodeError: - raise UnicodeError("Invalid claims encoding.") + except UnicodeError as unicode_error: + raise UnicodeError("Invalid claims encoding.") from unicode_error if not hasattr(claims, "keys"): raise TypeError("Provided claims is not a JSON dict. object") return (jose_header, claims) @staticmethod - def generate(claims, private_key_data=None, algo=None): + def generate(claims, private_key_data=None, algo=None, headers=None): """Generates and returns a new JSON Web Token. :param dict claims: JWT claims set :param str private_key_data: Decoded RSA private key data. + :param str algo: algorithm to be used. One of None, RS256, RS384 or RS512. + :param dict headers: additional headers for the claim. :rtype: str """ # Allow for unencrypted JWTs @@ -108,6 +110,8 @@ def generate(claims, private_key_data=None, algo=None): # Create the JOSE Header # https://tools.ietf.org/html/rfc7519#section-5 jose_header = {"typ": "JWT", "alg": algo} + if headers: + jose_header.update(headers) payload = "{}.{}".format( STRING_TOOLS.urlsafe_b64encode(json.dumps(jose_header).encode("utf-8")), STRING_TOOLS.urlsafe_b64encode(json.dumps(claims).encode("utf-8")), @@ -139,8 +143,7 @@ def generate(claims, private_key_data=None, algo=None): # pylint: disable=invalid-name class STRING_TOOLS: - """Tools and helpers for URL-safe string encoding. - """ + """Tools and helpers for URL-safe string encoding.""" # Some strings for ctype-style character classification whitespace = " \t\n\r\v\f" @@ -179,8 +182,10 @@ def _bytes_from_decode_data(str_data): if isinstance(str_data, str): try: return str_data.encode("ascii") - except: - raise ValueError("string argument should contain only ASCII characters") + except BaseException as error: + raise ValueError( + "string argument should contain only ASCII characters" + ) from error elif isinstance(str_data, bit_types): return str_data else: