diff --git a/adafruit_httpserver/__init__.py b/adafruit_httpserver/__init__.py index 5bfe1cb..5fd8099 100644 --- a/adafruit_httpserver/__init__.py +++ b/adafruit_httpserver/__init__.py @@ -25,6 +25,7 @@ from .authentication import ( Basic, + Token, Bearer, check_authentication, require_authentication, @@ -77,6 +78,8 @@ ACCEPTED_202, NO_CONTENT_204, PARTIAL_CONTENT_206, + MOVED_PERMANENTLY_301, + FOUND_302, TEMPORARY_REDIRECT_307, PERMANENT_REDIRECT_308, BAD_REQUEST_400, diff --git a/adafruit_httpserver/authentication.py b/adafruit_httpserver/authentication.py index a57564d..aa8f60d 100644 --- a/adafruit_httpserver/authentication.py +++ b/adafruit_httpserver/authentication.py @@ -28,17 +28,27 @@ def __str__(self) -> str: return f"Basic {self._value}" -class Bearer: - """Represents HTTP Bearer Token Authentication.""" +class Token: + """Represents HTTP Token Authentication.""" + + prefix = "Token" def __init__(self, token: str) -> None: self._value = token def __str__(self) -> str: - return f"Bearer {self._value}" + return f"{self.prefix} {self._value}" + + +class Bearer(Token): # pylint: disable=too-few-public-methods + """Represents HTTP Bearer Token Authentication.""" + + prefix = "Bearer" -def check_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> bool: +def check_authentication( + request: Request, auths: List[Union[Basic, Token, Bearer]] +) -> bool: """ Returns ``True`` if request is authorized by any of the authentications, ``False`` otherwise. @@ -47,7 +57,7 @@ def check_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> check_authentication(request, [Basic("username", "password")]) """ - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get_directive("Authorization") if auth_header is None: return False @@ -55,7 +65,9 @@ def check_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> return any(auth_header == str(auth) for auth in auths) -def require_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> None: +def require_authentication( + request: Request, auths: List[Union[Basic, Token, Bearer]] +) -> None: """ Checks if the request is authorized and raises ``AuthenticationError`` if not. diff --git a/adafruit_httpserver/headers.py b/adafruit_httpserver/headers.py index 6ac13c5..fca38d4 100644 --- a/adafruit_httpserver/headers.py +++ b/adafruit_httpserver/headers.py @@ -8,12 +8,14 @@ """ try: - from typing import Dict, Tuple + from typing import Dict, List, Union except ImportError: pass +from .interfaces import _IFieldStorage -class Headers: + +class Headers(_IFieldStorage): """ A dict-like class for storing HTTP headers. @@ -23,6 +25,8 @@ class Headers: Examples:: + headers = Headers("Content-Type: text/html\\r\\nContent-Length: 1024\\r\\n") + # or headers = Headers({"Content-Type": "text/html", "Content-Length": "1024"}) len(headers) @@ -45,60 +49,101 @@ class Headers: # True """ - _storage: Dict[str, Tuple[str, str]] + _storage: Dict[str, List[str]] + + def __init__(self, headers: Union[str, Dict[str, str]] = None) -> None: + self._storage = {} - def __init__(self, headers: Dict[str, str] = None) -> None: - headers = headers or {} + if isinstance(headers, str): + for header_line in headers.strip().splitlines(): + name, value = header_line.split(": ", 1) + self.add(name, value) + else: + for key, value in (headers or {}).items(): + self.add(key, value) - self._storage = {key.lower(): [key, value] for key, value in headers.items()} + def add(self, field_name: str, value: str): + """ + Adds a header with the given field name and value. + Allows adding multiple headers with the same name. + """ + self._add_field_value(field_name.lower(), value) - def get(self, name: str, default: str = None): + def get(self, field_name: str, default: str = None) -> Union[str, None]: """Returns the value for the given header name, or default if not found.""" - return self._storage.get(name.lower(), [None, default])[1] + return super().get(field_name.lower(), default) - def setdefault(self, name: str, default: str = None): - """Sets the value for the given header name if it does not exist.""" - return self._storage.setdefault(name.lower(), [name, default])[1] + def get_list(self, field_name: str) -> List[str]: + """Get the list of values of a field.""" + return super().get_list(field_name.lower()) + + def get_directive(self, name: str, default: str = None) -> Union[str, None]: + """ + Returns the main value (directive) for the given header name, or default if not found. + + Example:: + + headers = Headers({"Content-Type": "text/html; charset=utf-8"}) + headers.get_directive("Content-Type") + # 'text/html' + """ + + header_value = self.get(name) + if header_value is None: + return default + return header_value.split(";")[0].strip('" ') + + def get_parameter( + self, name: str, parameter: str, default: str = None + ) -> Union[str, None]: + """ + Returns the value of the given parameter for the given header name, or default if not found. - def items(self): - """Returns a list of (name, value) tuples.""" - return dict(self._storage.values()).items() + Example:: - def keys(self): - """Returns a list of header names.""" - return dict(self._storage.values()).keys() + headers = Headers({"Content-Type": "text/html; charset=utf-8"}) + headers.get_parameter("Content-Type", "charset") + # 'utf-8' + """ - def values(self): - """Returns a list of header values.""" - return dict(self._storage.values()).values() + header_value = self.get(name) + if header_value is None: + return default + for header_parameter in header_value.split(";"): + if header_parameter.strip().startswith(parameter): + return header_parameter.strip().split("=")[1].strip('" ') + return default + + def set(self, name: str, value: str): + """Sets the value for the given header name.""" + self._storage[name.lower()] = [value] + + def setdefault(self, name: str, default: str = None): + """Sets the value for the given header name if it does not exist.""" + return self._storage.setdefault(name.lower(), [default]) def update(self, headers: Dict[str, str]): """Updates the headers with the given dict.""" return self._storage.update( - {key.lower(): [key, value] for key, value in headers.items()} + {key.lower(): [value] for key, value in headers.items()} ) def copy(self): """Returns a copy of the headers.""" - return Headers(dict(self._storage.values())) + return Headers( + "\r\n".join( + f"{key}: {value}" for key in self.fields for value in self.get_list(key) + ) + ) def __getitem__(self, name: str): - return self._storage[name.lower()][1] + return super().__getitem__(name.lower()) def __setitem__(self, name: str, value: str): - self._storage[name.lower()] = [name, value] + self._storage[name.lower()] = [value] def __delitem__(self, name: str): del self._storage[name.lower()] - def __iter__(self): - return iter(dict(self._storage.values())) - - def __len__(self): - return len(self._storage) - def __contains__(self, key: str): - return key.lower() in self._storage.keys() - - def __repr__(self): - return f"{self.__class__.__name__}({dict(self._storage.values())})" + return super().__contains__(key.lower()) diff --git a/adafruit_httpserver/interfaces.py b/adafruit_httpserver/interfaces.py new file mode 100644 index 0000000..48b4e46 --- /dev/null +++ b/adafruit_httpserver/interfaces.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 Michał Pokusa +# +# SPDX-License-Identifier: MIT +""" +`adafruit_httpserver.interfaces` +==================================================== +* Author(s): Michał Pokusa +""" + +try: + from typing import List, Dict, Union, Any +except ImportError: + pass + + +class _IFieldStorage: + """Interface with shared methods for QueryParams, FormData and Headers.""" + + _storage: Dict[str, List[Any]] + + def _add_field_value(self, field_name: str, value: Any) -> None: + if field_name not in self._storage: + self._storage[field_name] = [value] + else: + self._storage[field_name].append(value) + + def get(self, field_name: str, default: Any = None) -> Union[Any, None]: + """Get the value of a field.""" + return self._storage.get(field_name, [default])[0] + + def get_list(self, field_name: str) -> List[Any]: + """Get the list of values of a field.""" + return self._storage.get(field_name, []) + + @property + def fields(self): + """Returns a list of field names.""" + return list(self._storage.keys()) + + def items(self): + """Returns a list of (name, value) tuples.""" + return [(key, value) for key in self.fields for value in self.get_list(key)] + + def keys(self): + """Returns a list of header names.""" + return self.fields + + def values(self): + """Returns a list of header values.""" + return [value for key in self.keys() for value in self.get_list(key)] + + def __getitem__(self, field_name: str): + return self._storage[field_name][0] + + def __iter__(self): + return iter(self._storage) + + def __len__(self) -> int: + return len(self._storage) + + def __contains__(self, key: str) -> bool: + return key in self._storage + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self._storage)})" + + +def _encode_html_entities(value: Union[str, None]) -> Union[str, None]: + """Encodes unsafe HTML characters that could enable XSS attacks.""" + if value is None: + return None + + return ( + str(value) + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +class _IXSSSafeFieldStorage(_IFieldStorage): + def get( + self, field_name: str, default: Any = None, *, safe=True + ) -> Union[Any, None]: + if safe: + return _encode_html_entities(super().get(field_name, default)) + + _debug_warning_nonencoded_output() + return super().get(field_name, default) + + def get_list(self, field_name: str, *, safe=True) -> List[Any]: + if safe: + return [ + _encode_html_entities(value) for value in super().get_list(field_name) + ] + + _debug_warning_nonencoded_output() + return super().get_list(field_name) + + +def _debug_warning_nonencoded_output(): + """Warns about XSS risks.""" + print( + "WARNING: Setting safe to False makes XSS vulnerabilities possible by " + "allowing access to raw untrusted values submitted by users. If this data is reflected " + "or shown within HTML without proper encoding it could enable Cross-Site Scripting." + ) diff --git a/adafruit_httpserver/request.py b/adafruit_httpserver/request.py index e44b66c..0f1bba7 100644 --- a/adafruit_httpserver/request.py +++ b/adafruit_httpserver/request.py @@ -20,76 +20,18 @@ import json from .headers import Headers +from .interfaces import _IFieldStorage, _IXSSSafeFieldStorage +from .methods import POST, PUT, PATCH, DELETE -class _IFieldStorage: - """Interface with shared methods for QueryParams and FormData.""" - - _storage: Dict[str, List[Union[str, bytes]]] - - def _add_field_value(self, field_name: str, value: Union[str, bytes]) -> None: - if field_name not in self._storage: - self._storage[field_name] = [value] - else: - self._storage[field_name].append(value) - - @staticmethod - def _encode_html_entities(value): - """Encodes unsafe HTML characters.""" - return ( - str(value) - .replace("&", "&") - .replace("<", "<") - .replace(">", ">") - .replace('"', """) - .replace("'", "'") - ) - - def get( - self, field_name: str, default: Any = None, *, safe=True - ) -> Union[str, bytes, None]: - """Get the value of a field.""" - if safe: - return self._encode_html_entities( - self._storage.get(field_name, [default])[0] - ) - - _debug_warning_nonencoded_output() - return self._storage.get(field_name, [default])[0] - - def get_list(self, field_name: str) -> List[Union[str, bytes]]: - """Get the list of values of a field.""" - return self._storage.get(field_name, []) - - @property - def fields(self): - """Returns a list of field names.""" - return list(self._storage.keys()) - - def __getitem__(self, field_name: str): - return self.get(field_name) - - def __iter__(self): - return iter(self._storage) - - def __len__(self): - return len(self._storage) - - def __contains__(self, key: str): - return key in self._storage - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({repr(self._storage)})" - - -class QueryParams(_IFieldStorage): +class QueryParams(_IXSSSafeFieldStorage): """ - Class for parsing and storing GET quer parameters requests. + Class for parsing and storing GET query parameters requests. Examples:: - query_params = QueryParams(b"foo=bar&baz=qux&baz=quux") - # QueryParams({"foo": "bar", "baz": ["qux", "quux"]}) + query_params = QueryParams("foo=bar&baz=qux&baz=quux") + # QueryParams({"foo": ["bar"], "baz": ["qux", "quux"]}) query_params.get("foo") # "bar" query_params["foo"] # "bar" @@ -99,7 +41,7 @@ class QueryParams(_IFieldStorage): query_params.fields # ["foo", "baz"] """ - _storage: Dict[str, List[Union[str, bytes]]] + _storage: Dict[str, List[str]] def __init__(self, query_string: str) -> None: self._storage = {} @@ -111,8 +53,99 @@ def __init__(self, query_string: str) -> None: elif query_param: self._add_field_value(query_param, "") + def _add_field_value(self, field_name: str, value: str) -> None: + super()._add_field_value(field_name, value) + + def get( + self, field_name: str, default: str = None, *, safe=True + ) -> Union[str, None]: + return super().get(field_name, default, safe=safe) -class FormData(_IFieldStorage): + def get_list(self, field_name: str, *, safe=True) -> List[str]: + return super().get_list(field_name, safe=safe) + + +class File: + """ + Class representing a file uploaded via POST. + + Examples:: + + file = request.form_data.files.get("uploaded_file") + # File(filename="foo.txt", content_type="text/plain", size=14) + + file.content + # "Hello, world!\\n" + """ + + filename: str + """Filename of the file.""" + + content_type: str + """Content type of the file.""" + + content: Union[str, bytes] + """Content of the file.""" + + def __init__( + self, filename: str, content_type: str, content: Union[str, bytes] + ) -> None: + self.filename = filename + self.content_type = content_type + self.content = content + + @property + def content_bytes(self) -> bytes: + """ + Content of the file as bytes. + It is recommended to use this instead of ``content`` as it will always return bytes. + + Example:: + + file = request.form_data.files.get("uploaded_file") + + with open(file.filename, "wb") as f: + f.write(file.content_bytes) + """ + return ( + self.content.encode("utf-8") + if isinstance(self.content, str) + else self.content + ) + + @property + def size(self) -> int: + """Length of the file content.""" + return len(self.content) + + def __repr__(self) -> str: + filename, content_type, size = ( + repr(self.filename), + repr(self.content_type), + repr(self.size), + ) + return f"{self.__class__.__name__}({filename=}, {content_type=}, {size=})" + + +class Files(_IFieldStorage): + """Class for files uploaded via POST.""" + + _storage: Dict[str, List[File]] + + def __init__(self) -> None: + self._storage = {} + + def _add_field_value(self, field_name: str, value: File) -> None: + super()._add_field_value(field_name, value) + + def get(self, field_name: str, default: Any = None) -> Union[File, Any, None]: + return super().get(field_name, default) + + def get_list(self, field_name: str) -> List[File]: + return super().get_list(field_name) + + +class FormData(_IXSSSafeFieldStorage): """ Class for parsing and storing form data from POST requests. @@ -124,7 +157,7 @@ class FormData(_IFieldStorage): form_data = FormData(b"foo=bar&baz=qux&baz=quuz", "application/x-www-form-urlencoded") # or form_data = FormData(b"foo=bar\\r\\nbaz=qux\\r\\nbaz=quux", "text/plain") - # FormData({"foo": "bar", "baz": "qux"}) + # FormData({"foo": ["bar"], "baz": ["qux", "quux"]}) form_data.get("foo") # "bar" form_data["foo"] # "bar" @@ -135,26 +168,43 @@ class FormData(_IFieldStorage): """ _storage: Dict[str, List[Union[str, bytes]]] + files: Files - def __init__(self, data: bytes, content_type: str) -> None: - self.content_type = content_type + @staticmethod + def _check_is_supported_content_type(content_type: str) -> None: + return content_type in ( + "application/x-www-form-urlencoded", + "multipart/form-data", + "text/plain", + ) + + def __init__(self, data: bytes, headers: Headers, *, debug: bool = False) -> None: self._storage = {} + self.files = Files() + + self.content_type = headers.get_directive("Content-Type") + content_length = int(headers.get("Content-Length", 0)) + + if debug and not self._check_is_supported_content_type(self.content_type): + _debug_unsupported_form_content_type(self.content_type) - if content_type.startswith("application/x-www-form-urlencoded"): - self._parse_x_www_form_urlencoded(data) + if self.content_type == "application/x-www-form-urlencoded": + self._parse_x_www_form_urlencoded(data[:content_length]) - elif content_type.startswith("multipart/form-data"): - boundary = content_type.split("boundary=")[1] - self._parse_multipart_form_data(data, boundary) + elif self.content_type == "multipart/form-data": + boundary = headers.get_parameter("Content-Type", "boundary") + self._parse_multipart_form_data(data[:content_length], boundary) - elif content_type.startswith("text/plain"): - self._parse_text_plain(data) + elif self.content_type == "text/plain": + self._parse_text_plain(data[:content_length]) def _parse_x_www_form_urlencoded(self, data: bytes) -> None: - decoded_data = data.decode() + if not (decoded_data := data.decode("utf-8").strip("&")): + return for field_name, value in [ - key_value.split("=", 1) for key_value in decoded_data.split("&") + key_value.split("=", 1) if "=" in key_value else (key_value, "") + for key_value in decoded_data.split("&") ]: self._add_field_value(field_name, value) @@ -162,22 +212,51 @@ def _parse_multipart_form_data(self, data: bytes, boundary: str) -> None: blocks = data.split(b"--" + boundary.encode())[1:-1] for block in blocks: - disposition, content = block.split(b"\r\n\r\n", 1) - field_name = disposition.split(b'"', 2)[1].decode() - value = content[:-2] + header_bytes, content_bytes = block.split(b"\r\n\r\n", 1) + headers = Headers(header_bytes.decode("utf-8").strip()) - self._add_field_value(field_name, value) + field_name = headers.get_parameter("Content-Disposition", "name") + filename = headers.get_parameter("Content-Disposition", "filename") + content_type = headers.get_directive("Content-Type", "text/plain") + charset = headers.get_parameter("Content-Type", "charset", "utf-8") + + content = content_bytes[:-2] # remove trailing \r\n + value = content.decode(charset) if content_type == "text/plain" else content + + # TODO: Other text content types (e.g. application/json) should be decoded as well and + + if filename is not None: + self.files._add_field_value( # pylint: disable=protected-access + field_name, File(filename, content_type, value) + ) + else: + self._add_field_value(field_name, value) def _parse_text_plain(self, data: bytes) -> None: - lines = data.split(b"\r\n")[:-1] + lines = data.decode("utf-8").split("\r\n")[:-1] for line in lines: - field_name, value = line.split(b"=", 1) + field_name, value = line.split("=", 1) + + self._add_field_value(field_name, value) - self._add_field_value(field_name.decode(), value.decode()) + def _add_field_value(self, field_name: str, value: Union[str, bytes]) -> None: + super()._add_field_value(field_name, value) + + def get( + self, field_name: str, default: Union[str, bytes] = None, *, safe=True + ) -> Union[str, bytes, None]: + return super().get(field_name, default, safe=safe) + + def get_list(self, field_name: str, *, safe=True) -> List[Union[str, bytes]]: + return super().get_list(field_name, safe=safe) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}({repr(self._storage)}, files={repr(self.files._storage)})" -class Request: +class Request: # pylint: disable=too-many-instance-attributes """ Incoming request, constructed from raw incoming bytes. It is passed as first argument to all route handlers. @@ -231,7 +310,7 @@ class Request: raw_request: bytes """ - Raw 'bytes' that were received from the client. + Raw ``bytes`` that were received from the client. Should **not** be modified directly. """ @@ -248,20 +327,19 @@ def __init__( self.client_address = client_address self.raw_request = raw_request self._form_data = None + self._cookies = None if raw_request is None: raise ValueError("raw_request cannot be None") - header_bytes = self._raw_header_bytes - try: ( self.method, self.path, self.query_params, self.http_version, - ) = self._parse_start_line(header_bytes) - self.headers = self._parse_headers(header_bytes) + self.headers, + ) = self._parse_request_header(self._raw_header_bytes) except Exception as error: raise ValueError("Unparseable raw_request: ", raw_request) from error @@ -274,6 +352,36 @@ def body(self) -> bytes: def body(self, body: bytes) -> None: self.raw_request = self._raw_header_bytes + b"\r\n\r\n" + body + @staticmethod + def _parse_cookies(cookie_header: str) -> None: + """Parse cookies from headers.""" + if cookie_header is None: + return {} + + return { + name: value.strip('"') + for name, value in [ + cookie.strip().split("=", 1) for cookie in cookie_header.split(";") + ] + } + + @property + def cookies(self) -> Dict[str, str]: + """ + Cookies sent with the request. + + Example:: + + request.headers["Cookie"] + # "foo=bar; baz=qux; foo=quux" + + request.cookies + # {"foo": "quux", "baz": "qux"} + """ + if self._cookies is None: + self._cookies = self._parse_cookies(self.headers.get("Cookie")) + return self._cookies + @property def form_data(self) -> Union[FormData, None]: """ @@ -318,12 +426,19 @@ def form_data(self) -> Union[FormData, None]: request.form_data.get_list("baz") # ["qux"] """ if self._form_data is None and self.method == "POST": - self._form_data = FormData(self.body, self.headers["Content-Type"]) + self._form_data = FormData(self.body, self.headers, debug=self.server.debug) return self._form_data def json(self) -> Union[dict, None]: - """Body of the request, as a JSON-decoded dictionary. Only available for POST requests.""" - return json.loads(self.body) if (self.body and self.method == "POST") else None + """ + Body of the request, as a JSON-decoded dictionary. + Only available for POST, PUT, PATCH and DELETE requests. + """ + return ( + json.loads(self.body) + if (self.body and self.method in (POST, PUT, PATCH, DELETE)) + else None + ) @property def _raw_header_bytes(self) -> bytes: @@ -340,12 +455,16 @@ def _raw_body_bytes(self) -> bytes: return self.raw_request[empty_line_index + 4 :] @staticmethod - def _parse_start_line(header_bytes: bytes) -> Tuple[str, str, Dict[str, str], str]: + def _parse_request_header( + header_bytes: bytes, + ) -> Tuple[str, str, QueryParams, str, Headers]: """Parse HTTP Start line to method, path, query_params and http_version.""" - start_line = header_bytes.decode("utf8").splitlines()[0] + start_line, headers_string = ( + header_bytes.decode("utf-8").strip().split("\r\n", 1) + ) - method, path, http_version = start_line.split() + method, path, http_version = start_line.strip().split() if "?" not in path: path += "?" @@ -353,27 +472,15 @@ def _parse_start_line(header_bytes: bytes) -> Tuple[str, str, Dict[str, str], st path, query_string = path.split("?", 1) query_params = QueryParams(query_string) + headers = Headers(headers_string) - return method, path, query_params, http_version - - @staticmethod - def _parse_headers(header_bytes: bytes) -> Headers: - """Parse HTTP headers from raw request.""" - header_lines = header_bytes.decode("utf8").splitlines()[1:] - - return Headers( - { - name: value - for header_line in header_lines - for name, value in [header_line.split(": ", 1)] - } - ) + return method, path, query_params, http_version, headers -def _debug_warning_nonencoded_output(): - """Warns about XSS risks.""" +def _debug_unsupported_form_content_type(content_type: str) -> None: + """Warns when an unsupported form content type is used.""" print( - "WARNING: Setting safe to False makes XSS vulnerabilities possible by " - "allowing access to raw untrusted values submitted by users. If this data is reflected " - "or shown within HTML without proper encoding it could enable Cross-Site Scripting." + f"WARNING: Unsupported Content-Type: {content_type}. " + "Only `application/x-www-form-urlencoded`, `multipart/form-data` and `text/plain` are " + "supported." ) diff --git a/adafruit_httpserver/response.py b/adafruit_httpserver/response.py index 6df7b74..4329a8f 100644 --- a/adafruit_httpserver/response.py +++ b/adafruit_httpserver/response.py @@ -31,6 +31,8 @@ Status, SWITCHING_PROTOCOLS_101, OK_200, + MOVED_PERMANENTLY_301, + FOUND_302, TEMPORARY_REDIRECT_307, PERMANENT_REDIRECT_308, ) @@ -58,6 +60,7 @@ def __init__( # pylint: disable=too-many-arguments *, status: Union[Status, Tuple[int, str]] = OK_200, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, content_type: str = None, ) -> None: """ @@ -65,6 +68,7 @@ def __init__( # pylint: disable=too-many-arguments :param str body: Body of response. Defaults to empty string. :param Status status: Status code and text. Defaults to 200 OK. :param Headers headers: Headers to include in response. Defaults to empty dict. + :param Dict[str, str] cookies: Cookies to be sent with the response. :param str content_type: Content type of response. Defaults to None. """ @@ -74,6 +78,7 @@ def __init__( # pylint: disable=too-many-arguments self._headers = ( headers.copy() if isinstance(headers, Headers) else Headers(headers) ) + self._cookies = cookies.copy() if cookies else {} self._content_type = content_type self._size = 0 @@ -94,6 +99,9 @@ def _send_headers( headers.setdefault("Content-Length", content_length) headers.setdefault("Connection", "close") + for cookie_name, cookie_value in self._cookies.items(): + headers.add("Set-Cookie", f"{cookie_name}={cookie_value}") + for header, value in headers.items(): if value is not None: response_message_header += f"{header}: {value}\r\n" @@ -164,6 +172,7 @@ def __init__( # pylint: disable=too-many-arguments *, status: Union[Status, Tuple[int, str]] = OK_200, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, content_type: str = None, as_attachment: bool = False, download_filename: str = None, @@ -176,14 +185,15 @@ def __init__( # pylint: disable=too-many-arguments :param str filename: Name of the file to send. :param str root_path: Path to the root directory from which to serve files. Defaults to server's ``root_path``. - :param Status status: Status code and text. Defaults to 200 OK. + :param Status status: Status code and text. Defaults to ``200 OK``. :param Headers headers: Headers to include in response. + :param Dict[str, str] cookies: Cookies to be sent with the response. :param str content_type: Content type of response. - :param bool as_attachment: If True, the file will be sent as an attachment. + :param bool as_attachment: If ``True``, the file will be sent as an attachment. :param str download_filename: Name of the file to send as an attachment. - :param int buffer_size: Size of the buffer used to send the file. Defaults to 1024. - :param bool head_only: If True, only headers will be sent. Defaults to False. - :param bool safe: If True, checks if ``filename`` is valid. Defaults to True. + :param int buffer_size: Size of the buffer used to send the file. Defaults to ``1024``. + :param bool head_only: If ``True``, only headers will be sent. Defaults to ``False``. + :param bool safe: If ``True``, checks if ``filename`` is valid. Defaults to ``True``. """ if safe: self._verify_file_path_is_valid(filename) @@ -191,6 +201,7 @@ def __init__( # pylint: disable=too-many-arguments super().__init__( request=request, headers=headers, + cookies=cookies, content_type=content_type, status=status, ) @@ -291,6 +302,7 @@ def __init__( # pylint: disable=too-many-arguments *, status: Union[Status, Tuple[int, str]] = OK_200, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, content_type: str = None, ) -> None: """ @@ -298,12 +310,14 @@ def __init__( # pylint: disable=too-many-arguments :param Generator body: Generator that yields chunks of data. :param Status status: Status object or tuple with code and message. :param Headers headers: Headers to be sent with the response. + :param Dict[str, str] cookies: Cookies to be sent with the response. :param str content_type: Content type of the response. """ super().__init__( request=request, headers=headers, + cookies=cookies, status=status, content_type=content_type, ) @@ -350,17 +364,20 @@ def __init__( data: Dict[Any, Any], *, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, status: Union[Status, Tuple[int, str]] = OK_200, ) -> None: """ :param Request request: Request that this is a response to. :param dict data: Data to be sent as JSON. :param Headers headers: Headers to include in response. + :param Dict[str, str] cookies: Cookies to be sent with the response. :param Status status: Status code and text. Defaults to 200 OK. """ super().__init__( request=request, headers=headers, + cookies=cookies, status=status, ) self._data = data @@ -393,19 +410,42 @@ def __init__( url: str, *, permanent: bool = False, + preserve_method: bool = False, + status: Union[Status, Tuple[int, str]] = None, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, ) -> None: """ + By default uses ``permament`` and ``preserve_method`` to determine the ``status`` code to + use, but if you prefer you can specify it directly. + + Note that ``301 Moved Permanently`` and ``302 Found`` can change the method to ``GET`` + while ``307 Temporary Redirect`` and ``308 Permanent Redirect`` preserve the method. + + More information: + https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#redirection_messages + :param Request request: Request that this is a response to. :param str url: URL to redirect to. - :param bool permanent: Whether to use a permanent redirect (308) or a temporary one (307). + :param bool permanent: Whether to use a permanent redirect or a temporary one. + :param bool preserve_method: Whether to preserve the method of the request. + :param Status status: Status object or tuple with code and message. :param Headers headers: Headers to include in response. + :param Dict[str, str] cookies: Cookies to be sent with the response. """ - super().__init__( - request, - status=PERMANENT_REDIRECT_308 if permanent else TEMPORARY_REDIRECT_307, - headers=headers, - ) + + if status is not None and (permanent or preserve_method): + raise ValueError( + "Cannot specify both status and permanent/preserve_method argument" + ) + + if status is None: + if preserve_method: + status = PERMANENT_REDIRECT_308 if permanent else TEMPORARY_REDIRECT_307 + else: + status = MOVED_PERMANENTLY_301 if permanent else FOUND_302 + + super().__init__(request, status=status, headers=headers, cookies=cookies) self._headers.update({"Location": url}) def _send(self) -> None: @@ -451,14 +491,17 @@ def __init__( # pylint: disable=too-many-arguments self, request: Request, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, ) -> None: """ :param Request request: Request object :param Headers headers: Headers to be sent with the response. + :param Dict[str, str] cookies: Cookies to be sent with the response. """ super().__init__( request=request, headers=headers, + cookies=cookies, content_type="text/event-stream", ) self._headers.setdefault("Cache-Control", "no-cache") @@ -558,18 +601,18 @@ def route_func(request: Request): @staticmethod def _check_request_initiates_handshake(request: Request): - if any( + if not all( [ - "websocket" not in request.headers.get("Upgrade", "").lower(), - "upgrade" not in request.headers.get("Connection", "").lower(), - "Sec-WebSocket-Key" not in request.headers, + "websocket" in request.headers.get_directive("Upgrade", "").lower(), + "upgrade" in request.headers.get_directive("Connection", "").lower(), + "Sec-WebSocket-Key" in request.headers, ] ): raise ValueError("Request does not initiate websocket handshake") @staticmethod def _process_sec_websocket_key(request: Request) -> str: - key = request.headers.get("Sec-WebSocket-Key") + key = request.headers.get_directive("Sec-WebSocket-Key") if key is None: raise ValueError("Request does not have Sec-WebSocket-Key header") @@ -583,11 +626,13 @@ def __init__( # pylint: disable=too-many-arguments self, request: Request, headers: Union[Headers, Dict[str, str]] = None, + cookies: Dict[str, str] = None, buffer_size: int = 1024, ) -> None: """ :param Request request: Request object :param Headers headers: Headers to be sent with the response. + :param Dict[str, str] cookies: Cookies to be sent with the response. :param int buffer_size: Size of the buffer used to send and receive messages. """ self._check_request_initiates_handshake(request) @@ -598,6 +643,7 @@ def __init__( # pylint: disable=too-many-arguments request=request, status=SWITCHING_PROTOCOLS_101, headers=headers, + cookies=cookies, ) self._headers.setdefault("Upgrade", "websocket") self._headers.setdefault("Connection", "Upgrade") diff --git a/adafruit_httpserver/route.py b/adafruit_httpserver/route.py index ba3b81e..6916620 100644 --- a/adafruit_httpserver/route.py +++ b/adafruit_httpserver/route.py @@ -8,7 +8,7 @@ """ try: - from typing import Callable, List, Set, Union, Tuple, Dict, TYPE_CHECKING + from typing import Callable, List, Iterable, Union, Tuple, Dict, TYPE_CHECKING if TYPE_CHECKING: from .response import Response @@ -26,7 +26,7 @@ class Route: def __init__( self, path: str = "", - methods: Union[str, Set[str]] = GET, + methods: Union[str, Iterable[str]] = GET, handler: Callable = None, *, append_slash: bool = False, @@ -40,7 +40,7 @@ def __init__( "...", r"[^/]+" ) + ("/?" if append_slash else "") self.methods = ( - set(methods) if isinstance(methods, (set, list)) else set([methods]) + set(methods) if isinstance(methods, (set, list, tuple)) else set([methods]) ) self.handler = handler @@ -118,7 +118,7 @@ def __repr__(self) -> str: def as_route( path: str, - methods: Union[str, Set[str]] = GET, + methods: Union[str, Iterable[str]] = GET, *, append_slash: bool = False, ) -> "Callable[[Callable[..., Response]], Route]": diff --git a/adafruit_httpserver/server.py b/adafruit_httpserver/server.py index 6168da3..40190f6 100644 --- a/adafruit_httpserver/server.py +++ b/adafruit_httpserver/server.py @@ -8,16 +8,17 @@ """ try: - from typing import Callable, Protocol, Union, List, Set, Tuple, Dict + from typing import Callable, Protocol, Union, List, Tuple, Dict, Iterable from socket import socket from socketpool import SocketPool except ImportError: pass from errno import EAGAIN, ECONNRESET, ETIMEDOUT +from time import monotonic, sleep from traceback import print_exception -from .authentication import Basic, Bearer, require_authentication +from .authentication import Basic, Token, Bearer, require_authentication from .exceptions import ( ServerStoppedError, AuthenticationError, @@ -79,7 +80,7 @@ def __init__( def route( self, path: str, - methods: Union[str, Set[str]] = GET, + methods: Union[str, Iterable[str]] = GET, *, append_slash: bool = False, ) -> Callable: @@ -170,7 +171,9 @@ def _verify_can_start(self, host: str, port: int) -> None: except OSError as error: raise RuntimeError(f"Cannot start server on {host}:{port}") from error - def serve_forever(self, host: str, port: int = 80) -> None: + def serve_forever( + self, host: str, port: int = 80, *, poll_interval: float = None + ) -> None: """ Wait for HTTP requests at the given host and port. Does not return. Ignores any exceptions raised by the handler function and continues to serve. @@ -178,6 +181,7 @@ def serve_forever(self, host: str, port: int = 80) -> None: :param str host: host name or IP address :param int port: port + :param float poll_interval: interval between polls in seconds """ self.start(host, port) @@ -190,6 +194,9 @@ def serve_forever(self, host: str, port: int = 80) -> None: except Exception: # pylint: disable=broad-except pass # Ignore exceptions in handler function + if poll_interval is not None: + sleep(poll_interval) + def start(self, host: str, port: int = 80) -> None: """ Start the HTTP server at the given host and port. Requires calling @@ -243,7 +250,7 @@ def _receive_request( request = Request(self, sock, client_address, header_bytes) - content_length = int(request.headers.get("Content-Length", 0)) + content_length = int(request.headers.get_directive("Content-Length", 0)) received_body_bytes = request.body # Receiving remaining body bytes @@ -358,6 +365,8 @@ def poll(self) -> str: conn, client_address = self._sock.accept() conn.settimeout(self._timeout) + _debug_start_time = monotonic() + # Receive the whole request if (request := self._receive_request(conn, client_address)) is None: conn.close() @@ -378,8 +387,10 @@ def poll(self) -> str: # Send the response response._send() # pylint: disable=protected-access + _debug_end_time = monotonic() + if self.debug: - _debug_response_sent(response) + _debug_response_sent(response, _debug_end_time - _debug_start_time) return REQUEST_HANDLED_RESPONSE_SENT @@ -398,7 +409,7 @@ def poll(self) -> str: conn.close() raise error # Raise the exception again to be handled by the user. - def require_authentication(self, auths: List[Union[Basic, Bearer]]) -> None: + def require_authentication(self, auths: List[Union[Basic, Token, Bearer]]) -> None: """ Requires authentication for all routes and files in ``root_path``. Any non-authenticated request will be rejected with a 401 status code. @@ -496,8 +507,8 @@ def _debug_started_server(server: "Server"): print(f"Started development server on http://{host}:{port}") -def _debug_response_sent(response: "Response"): - """Prints a message when after a response is sent.""" +def _debug_response_sent(response: "Response", time_elapsed: float): + """Prints a message after a response is sent.""" # pylint: disable=protected-access client_ip = response._request.client_address[0] method = response._request.method @@ -505,8 +516,11 @@ def _debug_response_sent(response: "Response"): req_size = len(response._request.raw_request) status = response._status res_size = response._size + time_elapsed_ms = f"{round(time_elapsed*1000)}ms" - print(f'{client_ip} -- "{method} {path}" {req_size} -- "{status}" {res_size}') + print( + f'{client_ip} -- "{method} {path}" {req_size} -- "{status}" {res_size} -- {time_elapsed_ms}' + ) def _debug_stopped_server(server: "Server"): # pylint: disable=unused-argument diff --git a/adafruit_httpserver/status.py b/adafruit_httpserver/status.py index 4219d9c..ea72284 100644 --- a/adafruit_httpserver/status.py +++ b/adafruit_httpserver/status.py @@ -43,6 +43,10 @@ def __eq__(self, other: "Status"): PARTIAL_CONTENT_206 = Status(206, "Partial Content") +MOVED_PERMANENTLY_301 = Status(301, "Moved Permanently") + +FOUND_302 = Status(302, "Found") + TEMPORARY_REDIRECT_307 = Status(307, "Temporary Redirect") PERMANENT_REDIRECT_308 = Status(308, "Permanent Redirect") diff --git a/docs/api.rst b/docs/api.rst index 4da3815..8aeda4e 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,6 +19,7 @@ .. automodule:: adafruit_httpserver.headers :members: + :inherited-members: .. automodule:: adafruit_httpserver.status :members: diff --git a/docs/examples.rst b/docs/examples.rst index 05e07de..f37551a 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -95,7 +95,8 @@ It is possible to use the MDNS protocol to make the server accessible via a host to an IP address. It is worth noting that it takes a bit longer to get the response from the server when accessing it via the hostname. -In this example, the server is accessible via ``http://custom-mdns-hostname/`` and ``http://custom-mdns-hostname.local/``. +In this example, the server is accessible via the IP and ``http://custom-mdns-hostname.local/``. +On some routers it is also possible to use ``http://custom-mdns-hostname/``, but **this is not guaranteed to work**. .. literalinclude:: ../examples/httpserver_mdns.py :caption: examples/httpserver_mdns.py @@ -169,9 +170,9 @@ It is important to use correct ``enctype``, depending on the type of data you wa - ``application/x-www-form-urlencoded`` - For sending simple text data without any special characters including spaces. If you use it, values will be automatically parsed as strings, but special characters will be URL encoded e.g. ``"Hello World! ^-$%"`` will be saved as ``"Hello+World%21+%5E-%24%25"`` -- ``multipart/form-data`` - For sending text and binary files and/or text data with special characters - When used, values will **not** be automatically parsed as strings, they will stay as bytes instead. - e.g. ``"Hello World! ^-$%"`` will be saved as ``b'Hello World! ^-$%'``, which can be decoded using ``.decode()`` method. +- ``multipart/form-data`` - For sending textwith special characters and files + When used, non-file values will be automatically parsed as strings and non plain text files will be saved as ``bytes``. + e.g. ``"Hello World! ^-$%"`` will be saved as ``'Hello World! ^-$%'``, and e.g. a PNG file will be saved as ``b'\x89PNG\r\n\x1a\n\x00\...``. - ``text/plain`` - For sending text data with special characters. If used, values will be automatically parsed as strings, including special characters, emojis etc. e.g. ``"Hello World! ^-$%"`` will be saved as ``"Hello World! ^-$%"``, this is the **recommended** option. @@ -185,6 +186,20 @@ return only the first one. :emphasize-lines: 32,47,50 :linenos: +Cookies +--------------------- + +You can use cookies to store data on the client side, that will be sent back to the server with every request. +They are often used to store authentication tokens, session IDs, but also user preferences e.g. theme. + +To access cookies, use ``request.cookies`` dictionary. +In order to set cookies, pass ``cookies`` dictionary to ``Response`` constructor or manually add ``Set-Cookie`` header. + +.. literalinclude:: ../examples/httpserver_cookies.py + :caption: examples/httpserver_cookies.py + :emphasize-lines: 70,74-75,82 + :linenos: + Chunked response ---------------- @@ -243,7 +258,7 @@ If you want to apply authentication to the whole server, you need to call ``.req .. literalinclude:: ../examples/httpserver_authentication_server.py :caption: examples/httpserver_authentication_server.py - :emphasize-lines: 8,11-15,19 + :emphasize-lines: 8,11-16,20 :linenos: On the other hand, if you want to apply authentication to a set of routes, you need to call ``require_authentication`` function. @@ -251,7 +266,7 @@ In both cases you can check if ``request`` is authenticated by calling ``check_a .. literalinclude:: ../examples/httpserver_authentication_handlers.py :caption: examples/httpserver_authentication_handlers.py - :emphasize-lines: 9-15,21-25,33,47,59 + :emphasize-lines: 9-16,22-27,35,49,61 :linenos: Redirects @@ -262,10 +277,13 @@ Sometimes you might want to redirect the user to a different URL, either on the You can do that by returning ``Redirect`` from your handler function. You can specify wheter the redirect is permanent or temporary by passing ``permanent=...`` to ``Redirect``. +If you need the redirect to preserve the original request method, you can set ``preserve_method=True``. + +Alternatively, you can pass a ``status`` object directly to ``Redirect`` constructor. .. literalinclude:: ../examples/httpserver_redirects.py :caption: examples/httpserver_redirects.py - :emphasize-lines: 14-18,26,38 + :emphasize-lines: 22-26,32,38,50,62 :linenos: Server-Sent Events @@ -340,19 +358,19 @@ occurs during handling of the request in ``.serve_forever()``. This is how the logs might look like when debug mode is enabled:: Started development server on http://192.168.0.100:80 - 192.168.0.101 -- "GET /" 194 -- "200 OK" 154 - 192.168.0.101 -- "GET /example" 134 -- "404 Not Found" 172 - 192.168.0.102 -- "POST /api" 1241 -- "401 Unauthorized" 95 + 192.168.0.101 -- "GET /" 194 -- "200 OK" 154 -- 96ms + 192.168.0.101 -- "GET /example" 134 -- "404 Not Found" 172 -- 123ms + 192.168.0.102 -- "POST /api" 1241 -- "401 Unauthorized" 95 -- 64ms Traceback (most recent call last): ... File "code.py", line 55, in example_handler KeyError: non_existent_key - 192.168.0.103 -- "GET /index.html" 242 -- "200 OK" 154 + 192.168.0.103 -- "GET /index.html" 242 -- "200 OK" 154 -- 182ms Stopped development server This is the default format of the logs:: - {client_ip} -- "{request_method} {path}" {request_size} -- "{response_status}" {response_size} + {client_ip} -- "{request_method} {path}" {request_size} -- "{response_status}" {response_size} -- {elapsed_ms} If you need more information about the server or request, or you want it in a different format you can modify functions at the bottom of ``adafruit_httpserver/server.py`` that start with ``_debug_...``. diff --git a/examples/httpserver_authentication_handlers.py b/examples/httpserver_authentication_handlers.py index d917ede..d1bae2e 100644 --- a/examples/httpserver_authentication_handlers.py +++ b/examples/httpserver_authentication_handlers.py @@ -9,6 +9,7 @@ from adafruit_httpserver.authentication import ( AuthenticationError, Basic, + Token, Bearer, check_authentication, require_authentication, @@ -21,6 +22,7 @@ # Create a list of available authentication methods. auths = [ Basic("user", "password"), + Token("2db53340-4f9c-4f70-9037-d25bee77eca6"), Bearer("642ec696-2a79-4d60-be3a-7c9a3164d766"), ] diff --git a/examples/httpserver_authentication_server.py b/examples/httpserver_authentication_server.py index 298e28c..8dc5936 100644 --- a/examples/httpserver_authentication_server.py +++ b/examples/httpserver_authentication_server.py @@ -5,12 +5,13 @@ import socketpool import wifi -from adafruit_httpserver import Server, Request, Response, Basic, Bearer +from adafruit_httpserver import Server, Request, Response, Basic, Token, Bearer # Create a list of available authentication methods. auths = [ Basic("user", "password"), + Token("2db53340-4f9c-4f70-9037-d25bee77eca6"), Bearer("642ec696-2a79-4d60-be3a-7c9a3164d766"), ] diff --git a/examples/httpserver_cookies.py b/examples/httpserver_cookies.py new file mode 100644 index 0000000..1f1e91c --- /dev/null +++ b/examples/httpserver_cookies.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2023 Michał Pokusa +# +# SPDX-License-Identifier: Unlicense + +import socketpool +import wifi + +from adafruit_httpserver import Server, Request, Response, GET, Headers + + +pool = socketpool.SocketPool(wifi.radio) +server = Server(pool, debug=True) + + +THEMES = { + "dark": { + "background-color": "#1c1c1c", + "color": "white", + "button-color": "#181818", + }, + "light": { + "background-color": "white", + "color": "#1c1c1c", + "button-color": "white", + }, +} + + +def themed_template(user_preferred_theme: str): + theme = THEMES[user_preferred_theme] + + return f""" + +
++ After changing the theme, close the tab and open again. + Notice that theme stays the same. +
+ + + """ + + +@server.route("/", GET) +def themed_from_cookie(request: Request): + """ + Serve a simple themed page, based on the user's cookie. + """ + + user_theme = request.cookies.get("theme", "light") + wanted_theme = request.query_params.get("theme", user_theme) + + headers = Headers() + headers.add("Set-Cookie", "cookie1=value1") + headers.add("Set-Cookie", "cookie2=value2") + + return Response( + request, + themed_template(wanted_theme), + content_type="text/html", + headers=headers, + cookies={} if user_theme == wanted_theme else {"theme": wanted_theme}, + ) + + +server.serve_forever(str(wifi.radio.ipv4_address)) diff --git a/examples/httpserver_redirects.py b/examples/httpserver_redirects.py index 8b38ca9..1fb7e9e 100644 --- a/examples/httpserver_redirects.py +++ b/examples/httpserver_redirects.py @@ -5,7 +5,15 @@ import socketpool import wifi -from adafruit_httpserver import Server, Request, Response, Redirect, NOT_FOUND_404 +from adafruit_httpserver import ( + Server, + Request, + Response, + Redirect, + POST, + NOT_FOUND_404, + MOVED_PERMANENTLY_301, +) pool = socketpool.SocketPool(wifi.radio) @@ -20,19 +28,35 @@ @server.route("/blinka") def redirect_blinka(request: Request): - """ - Always redirect to a Blinka page as permanent redirect. - """ + """Always redirect to a Blinka page as permanent redirect.""" return Redirect(request, "https://circuitpython.org/blinka", permanent=True) +@server.route("/adafruit") +def redirect_adafruit(request: Request): + """Permanent redirect to Adafruit website with explicitly set status code.""" + return Redirect(request, "https://www.adafruit.com/", status=MOVED_PERMANENTLY_301) + + +@server.route("/fake-login", POST) +def fake_login(request: Request): + """Fake login page.""" + return Response(request, "Fake login page with POST data preserved.") + + +@server.route("/login", POST) +def temporary_login_redirect(request: Request): + """Temporary moved login page with preserved POST data.""" + return Redirect(request, "/fake-login", preserve_method=True) + + @server.route("/