Skip to content

Commit 4042554

Browse files
authored
Merge pull request #67 from michalpokusa/redirect-301-302-token-form-files-cookies
301/302 Redirects, FormData files, Cookies, Token authentication
2 parents be65668 + d945f6f commit 4042554

15 files changed

+676
-204
lines changed

adafruit_httpserver/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from .authentication import (
2727
Basic,
28+
Token,
2829
Bearer,
2930
check_authentication,
3031
require_authentication,
@@ -77,6 +78,8 @@
7778
ACCEPTED_202,
7879
NO_CONTENT_204,
7980
PARTIAL_CONTENT_206,
81+
MOVED_PERMANENTLY_301,
82+
FOUND_302,
8083
TEMPORARY_REDIRECT_307,
8184
PERMANENT_REDIRECT_308,
8285
BAD_REQUEST_400,

adafruit_httpserver/authentication.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,27 @@ def __str__(self) -> str:
2828
return f"Basic {self._value}"
2929

3030

31-
class Bearer:
32-
"""Represents HTTP Bearer Token Authentication."""
31+
class Token:
32+
"""Represents HTTP Token Authentication."""
33+
34+
prefix = "Token"
3335

3436
def __init__(self, token: str) -> None:
3537
self._value = token
3638

3739
def __str__(self) -> str:
38-
return f"Bearer {self._value}"
40+
return f"{self.prefix} {self._value}"
41+
42+
43+
class Bearer(Token): # pylint: disable=too-few-public-methods
44+
"""Represents HTTP Bearer Token Authentication."""
45+
46+
prefix = "Bearer"
3947

4048

41-
def check_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> bool:
49+
def check_authentication(
50+
request: Request, auths: List[Union[Basic, Token, Bearer]]
51+
) -> bool:
4252
"""
4353
Returns ``True`` if request is authorized by any of the authentications, ``False`` otherwise.
4454
@@ -47,15 +57,17 @@ def check_authentication(request: Request, auths: List[Union[Basic, Bearer]]) ->
4757
check_authentication(request, [Basic("username", "password")])
4858
"""
4959

50-
auth_header = request.headers.get("Authorization")
60+
auth_header = request.headers.get_directive("Authorization")
5161

5262
if auth_header is None:
5363
return False
5464

5565
return any(auth_header == str(auth) for auth in auths)
5666

5767

58-
def require_authentication(request: Request, auths: List[Union[Basic, Bearer]]) -> None:
68+
def require_authentication(
69+
request: Request, auths: List[Union[Basic, Token, Bearer]]
70+
) -> None:
5971
"""
6072
Checks if the request is authorized and raises ``AuthenticationError`` if not.
6173

adafruit_httpserver/headers.py

+79-34
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
"""
99

1010
try:
11-
from typing import Dict, Tuple
11+
from typing import Dict, List, Union
1212
except ImportError:
1313
pass
1414

15+
from .interfaces import _IFieldStorage
1516

16-
class Headers:
17+
18+
class Headers(_IFieldStorage):
1719
"""
1820
A dict-like class for storing HTTP headers.
1921
@@ -23,6 +25,8 @@ class Headers:
2325
2426
Examples::
2527
28+
headers = Headers("Content-Type: text/html\\r\\nContent-Length: 1024\\r\\n")
29+
# or
2630
headers = Headers({"Content-Type": "text/html", "Content-Length": "1024"})
2731
2832
len(headers)
@@ -45,60 +49,101 @@ class Headers:
4549
# True
4650
"""
4751

48-
_storage: Dict[str, Tuple[str, str]]
52+
_storage: Dict[str, List[str]]
53+
54+
def __init__(self, headers: Union[str, Dict[str, str]] = None) -> None:
55+
self._storage = {}
4956

50-
def __init__(self, headers: Dict[str, str] = None) -> None:
51-
headers = headers or {}
57+
if isinstance(headers, str):
58+
for header_line in headers.strip().splitlines():
59+
name, value = header_line.split(": ", 1)
60+
self.add(name, value)
61+
else:
62+
for key, value in (headers or {}).items():
63+
self.add(key, value)
5264

53-
self._storage = {key.lower(): [key, value] for key, value in headers.items()}
65+
def add(self, field_name: str, value: str):
66+
"""
67+
Adds a header with the given field name and value.
68+
Allows adding multiple headers with the same name.
69+
"""
70+
self._add_field_value(field_name.lower(), value)
5471

55-
def get(self, name: str, default: str = None):
72+
def get(self, field_name: str, default: str = None) -> Union[str, None]:
5673
"""Returns the value for the given header name, or default if not found."""
57-
return self._storage.get(name.lower(), [None, default])[1]
74+
return super().get(field_name.lower(), default)
5875

59-
def setdefault(self, name: str, default: str = None):
60-
"""Sets the value for the given header name if it does not exist."""
61-
return self._storage.setdefault(name.lower(), [name, default])[1]
76+
def get_list(self, field_name: str) -> List[str]:
77+
"""Get the list of values of a field."""
78+
return super().get_list(field_name.lower())
79+
80+
def get_directive(self, name: str, default: str = None) -> Union[str, None]:
81+
"""
82+
Returns the main value (directive) for the given header name, or default if not found.
83+
84+
Example::
85+
86+
headers = Headers({"Content-Type": "text/html; charset=utf-8"})
87+
headers.get_directive("Content-Type")
88+
# 'text/html'
89+
"""
90+
91+
header_value = self.get(name)
92+
if header_value is None:
93+
return default
94+
return header_value.split(";")[0].strip('" ')
95+
96+
def get_parameter(
97+
self, name: str, parameter: str, default: str = None
98+
) -> Union[str, None]:
99+
"""
100+
Returns the value of the given parameter for the given header name, or default if not found.
62101
63-
def items(self):
64-
"""Returns a list of (name, value) tuples."""
65-
return dict(self._storage.values()).items()
102+
Example::
66103
67-
def keys(self):
68-
"""Returns a list of header names."""
69-
return dict(self._storage.values()).keys()
104+
headers = Headers({"Content-Type": "text/html; charset=utf-8"})
105+
headers.get_parameter("Content-Type", "charset")
106+
# 'utf-8'
107+
"""
70108

71-
def values(self):
72-
"""Returns a list of header values."""
73-
return dict(self._storage.values()).values()
109+
header_value = self.get(name)
110+
if header_value is None:
111+
return default
112+
for header_parameter in header_value.split(";"):
113+
if header_parameter.strip().startswith(parameter):
114+
return header_parameter.strip().split("=")[1].strip('" ')
115+
return default
116+
117+
def set(self, name: str, value: str):
118+
"""Sets the value for the given header name."""
119+
self._storage[name.lower()] = [value]
120+
121+
def setdefault(self, name: str, default: str = None):
122+
"""Sets the value for the given header name if it does not exist."""
123+
return self._storage.setdefault(name.lower(), [default])
74124

75125
def update(self, headers: Dict[str, str]):
76126
"""Updates the headers with the given dict."""
77127
return self._storage.update(
78-
{key.lower(): [key, value] for key, value in headers.items()}
128+
{key.lower(): [value] for key, value in headers.items()}
79129
)
80130

81131
def copy(self):
82132
"""Returns a copy of the headers."""
83-
return Headers(dict(self._storage.values()))
133+
return Headers(
134+
"\r\n".join(
135+
f"{key}: {value}" for key in self.fields for value in self.get_list(key)
136+
)
137+
)
84138

85139
def __getitem__(self, name: str):
86-
return self._storage[name.lower()][1]
140+
return super().__getitem__(name.lower())
87141

88142
def __setitem__(self, name: str, value: str):
89-
self._storage[name.lower()] = [name, value]
143+
self._storage[name.lower()] = [value]
90144

91145
def __delitem__(self, name: str):
92146
del self._storage[name.lower()]
93147

94-
def __iter__(self):
95-
return iter(dict(self._storage.values()))
96-
97-
def __len__(self):
98-
return len(self._storage)
99-
100148
def __contains__(self, key: str):
101-
return key.lower() in self._storage.keys()
102-
103-
def __repr__(self):
104-
return f"{self.__class__.__name__}({dict(self._storage.values())})"
149+
return super().__contains__(key.lower())

adafruit_httpserver/interfaces.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 Michał Pokusa
2+
#
3+
# SPDX-License-Identifier: MIT
4+
"""
5+
`adafruit_httpserver.interfaces`
6+
====================================================
7+
* Author(s): Michał Pokusa
8+
"""
9+
10+
try:
11+
from typing import List, Dict, Union, Any
12+
except ImportError:
13+
pass
14+
15+
16+
class _IFieldStorage:
17+
"""Interface with shared methods for QueryParams, FormData and Headers."""
18+
19+
_storage: Dict[str, List[Any]]
20+
21+
def _add_field_value(self, field_name: str, value: Any) -> None:
22+
if field_name not in self._storage:
23+
self._storage[field_name] = [value]
24+
else:
25+
self._storage[field_name].append(value)
26+
27+
def get(self, field_name: str, default: Any = None) -> Union[Any, None]:
28+
"""Get the value of a field."""
29+
return self._storage.get(field_name, [default])[0]
30+
31+
def get_list(self, field_name: str) -> List[Any]:
32+
"""Get the list of values of a field."""
33+
return self._storage.get(field_name, [])
34+
35+
@property
36+
def fields(self):
37+
"""Returns a list of field names."""
38+
return list(self._storage.keys())
39+
40+
def items(self):
41+
"""Returns a list of (name, value) tuples."""
42+
return [(key, value) for key in self.fields for value in self.get_list(key)]
43+
44+
def keys(self):
45+
"""Returns a list of header names."""
46+
return self.fields
47+
48+
def values(self):
49+
"""Returns a list of header values."""
50+
return [value for key in self.keys() for value in self.get_list(key)]
51+
52+
def __getitem__(self, field_name: str):
53+
return self._storage[field_name][0]
54+
55+
def __iter__(self):
56+
return iter(self._storage)
57+
58+
def __len__(self) -> int:
59+
return len(self._storage)
60+
61+
def __contains__(self, key: str) -> bool:
62+
return key in self._storage
63+
64+
def __repr__(self) -> str:
65+
return f"{self.__class__.__name__}({repr(self._storage)})"
66+
67+
68+
def _encode_html_entities(value: Union[str, None]) -> Union[str, None]:
69+
"""Encodes unsafe HTML characters that could enable XSS attacks."""
70+
if value is None:
71+
return None
72+
73+
return (
74+
str(value)
75+
.replace("&", "&")
76+
.replace("<", "&lt;")
77+
.replace(">", "&gt;")
78+
.replace('"', "&quot;")
79+
.replace("'", "&apos;")
80+
)
81+
82+
83+
class _IXSSSafeFieldStorage(_IFieldStorage):
84+
def get(
85+
self, field_name: str, default: Any = None, *, safe=True
86+
) -> Union[Any, None]:
87+
if safe:
88+
return _encode_html_entities(super().get(field_name, default))
89+
90+
_debug_warning_nonencoded_output()
91+
return super().get(field_name, default)
92+
93+
def get_list(self, field_name: str, *, safe=True) -> List[Any]:
94+
if safe:
95+
return [
96+
_encode_html_entities(value) for value in super().get_list(field_name)
97+
]
98+
99+
_debug_warning_nonencoded_output()
100+
return super().get_list(field_name)
101+
102+
103+
def _debug_warning_nonencoded_output():
104+
"""Warns about XSS risks."""
105+
print(
106+
"WARNING: Setting safe to False makes XSS vulnerabilities possible by "
107+
"allowing access to raw untrusted values submitted by users. If this data is reflected "
108+
"or shown within HTML without proper encoding it could enable Cross-Site Scripting."
109+
)

0 commit comments

Comments
 (0)