Skip to content

Commit 9a79039

Browse files
committed
Refactor of QueryParams and FormData, moved interfaces to separate file
1 parent 5da48c2 commit 9a79039

File tree

3 files changed

+209
-80
lines changed

3 files changed

+209
-80
lines changed

adafruit_httpserver/interfaces.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 Michał Pokusa
2+
#
3+
# SPDX-License-Identifier: MIT
4+
"""
5+
`adafruit_httpserver.interfaces`
6+
====================================================
7+
* Author(s): Michał Pokusa
8+
"""
9+
10+
try:
11+
from typing import List, Dict, Union, Any
12+
except ImportError:
13+
pass
14+
15+
16+
class _IFieldStorage:
17+
"""Interface with shared methods for QueryParams and FormData."""
18+
19+
_storage: Dict[str, List[Any]]
20+
21+
def _add_field_value(self, field_name: str, value: Any) -> None:
22+
if field_name not in self._storage:
23+
self._storage[field_name] = [value]
24+
else:
25+
self._storage[field_name].append(value)
26+
27+
def get(self, field_name: str, default: Any = None) -> Union[Any, None]:
28+
"""Get the value of a field."""
29+
return self._storage.get(field_name, [default])[0]
30+
31+
def get_list(self, field_name: str) -> List[Any]:
32+
"""Get the list of values of a field."""
33+
return self._storage.get(field_name, [])
34+
35+
@property
36+
def fields(self):
37+
"""Returns a list of field names."""
38+
return list(self._storage.keys())
39+
40+
def __getitem__(self, field_name: str):
41+
return self._storage[field_name][0]
42+
43+
def __iter__(self):
44+
return iter(self._storage)
45+
46+
def __len__(self) -> int:
47+
return len(self._storage)
48+
49+
def __contains__(self, key: str) -> bool:
50+
return key in self._storage
51+
52+
def __repr__(self) -> str:
53+
return f"{self.__class__.__name__}({repr(self._storage)})"
54+
55+
56+
def _encode_html_entities(value: str) -> str:
57+
"""Encodes unsafe HTML characters that could enable XSS attacks."""
58+
return (
59+
str(value)
60+
.replace("&", "&")
61+
.replace("<", "&lt;")
62+
.replace(">", "&gt;")
63+
.replace('"', "&quot;")
64+
.replace("'", "&apos;")
65+
)
66+
67+
68+
class _IXSSSafeFieldStorage(_IFieldStorage):
69+
def get(
70+
self, field_name: str, default: Any = None, *, safe=True
71+
) -> Union[Any, None]:
72+
if safe:
73+
return _encode_html_entities(super().get(field_name, default))
74+
75+
_debug_warning_nonencoded_output()
76+
return super().get(field_name, default)
77+
78+
def get_list(self, field_name: str, *, safe=True) -> List[Any]:
79+
if safe:
80+
return [
81+
_encode_html_entities(value) for value in super().get_list(field_name)
82+
]
83+
84+
_debug_warning_nonencoded_output()
85+
return super().get_list(field_name)
86+
87+
88+
def _debug_warning_nonencoded_output():
89+
"""Warns about XSS risks."""
90+
print(
91+
"WARNING: Setting safe to False makes XSS vulnerabilities possible by "
92+
"allowing access to raw untrusted values submitted by users. If this data is reflected "
93+
"or shown within HTML without proper encoding it could enable Cross-Site Scripting."
94+
)

adafruit_httpserver/request.py

Lines changed: 112 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,76 +20,17 @@
2020
import json
2121

2222
from .headers import Headers
23+
from .interfaces import _IFieldStorage, _IXSSSafeFieldStorage
2324

2425

25-
class _IFieldStorage:
26-
"""Interface with shared methods for QueryParams and FormData."""
27-
28-
_storage: Dict[str, List[Union[str, bytes]]]
29-
30-
def _add_field_value(self, field_name: str, value: Union[str, bytes]) -> None:
31-
if field_name not in self._storage:
32-
self._storage[field_name] = [value]
33-
else:
34-
self._storage[field_name].append(value)
35-
36-
@staticmethod
37-
def _encode_html_entities(value: str) -> str:
38-
"""Encodes unsafe HTML characters."""
39-
return (
40-
str(value)
41-
.replace("&", "&amp;")
42-
.replace("<", "&lt;")
43-
.replace(">", "&gt;")
44-
.replace('"', "&quot;")
45-
.replace("'", "&apos;")
46-
)
47-
48-
def get(
49-
self, field_name: str, default: Any = None, *, safe=True
50-
) -> Union[str, bytes, None]:
51-
"""Get the value of a field."""
52-
if safe:
53-
return self._encode_html_entities(
54-
self._storage.get(field_name, [default])[0]
55-
)
56-
57-
_debug_warning_nonencoded_output()
58-
return self._storage.get(field_name, [default])[0]
59-
60-
def get_list(self, field_name: str) -> List[Union[str, bytes]]:
61-
"""Get the list of values of a field."""
62-
return self._storage.get(field_name, [])
63-
64-
@property
65-
def fields(self):
66-
"""Returns a list of field names."""
67-
return list(self._storage.keys())
68-
69-
def __getitem__(self, field_name: str):
70-
return self.get(field_name)
71-
72-
def __iter__(self):
73-
return iter(self._storage)
74-
75-
def __len__(self):
76-
return len(self._storage)
77-
78-
def __contains__(self, key: str):
79-
return key in self._storage
80-
81-
def __repr__(self) -> str:
82-
return f"{self.__class__.__name__}({repr(self._storage)})"
83-
84-
85-
class QueryParams(_IFieldStorage):
26+
class QueryParams(_IXSSSafeFieldStorage):
8627
"""
8728
Class for parsing and storing GET query parameters requests.
8829
8930
Examples::
9031
9132
query_params = QueryParams("foo=bar&baz=qux&baz=quux")
92-
# QueryParams({"foo": "bar", "baz": ["qux", "quux"]})
33+
# QueryParams({"foo": ["bar"], "baz": ["qux", "quux"]})
9334
9435
query_params.get("foo") # "bar"
9536
query_params["foo"] # "bar"
@@ -111,8 +52,80 @@ def __init__(self, query_string: str) -> None:
11152
elif query_param:
11253
self._add_field_value(query_param, "")
11354

55+
def _add_field_value(self, field_name: str, value: str) -> None:
56+
super()._add_field_value(field_name, value)
57+
58+
def get(
59+
self, field_name: str, default: str = None, *, safe=True
60+
) -> Union[str, None]:
61+
return super().get(field_name, default, safe=safe)
62+
63+
def get_list(self, field_name: str, *, safe=True) -> List[str]:
64+
return super().get_list(field_name, safe=safe)
65+
11466

115-
class FormData(_IFieldStorage):
67+
class File:
68+
"""
69+
Class representing a file uploaded via POST.
70+
71+
Examples::
72+
73+
file = request.form_data.files.get("uploaded_file")
74+
# File(filename="foo.txt", content_type="text/plain", size=14)
75+
76+
file.content
77+
# "Hello, world!\\n"
78+
"""
79+
80+
filename: str
81+
"""Filename of the file."""
82+
83+
content_type: str
84+
"""Content type of the file."""
85+
86+
content: Union[str, bytes]
87+
"""Content of the file."""
88+
89+
def __init__(
90+
self, filename: str, content_type: str, content: Union[str, bytes]
91+
) -> None:
92+
self.filename = filename
93+
self.content_type = content_type
94+
self.content = content
95+
96+
@property
97+
def size(self) -> int:
98+
"""Length of the file content."""
99+
return len(self.content)
100+
101+
def __repr__(self) -> str:
102+
filename, content_type, size = (
103+
repr(self.filename),
104+
repr(self.content_type),
105+
repr(self.size),
106+
)
107+
return f"{self.__class__.__name__}({filename=}, {content_type=}, {size=})"
108+
109+
110+
class Files(_IFieldStorage):
111+
"""Class for files uploaded via POST."""
112+
113+
_storage: Dict[str, List[File]]
114+
115+
def __init__(self) -> None:
116+
self._storage = {}
117+
118+
def _add_field_value(self, field_name: str, value: File) -> None:
119+
super()._add_field_value(field_name, value)
120+
121+
def get(self, field_name: str, default: Any = None) -> Union[File, Any, None]:
122+
return super().get(field_name, default)
123+
124+
def get_list(self, field_name: str) -> List[File]:
125+
return super().get_list(field_name)
126+
127+
128+
class FormData(_IXSSSafeFieldStorage):
116129
"""
117130
Class for parsing and storing form data from POST requests.
118131
@@ -124,7 +137,7 @@ class FormData(_IFieldStorage):
124137
form_data = FormData(b"foo=bar&baz=qux&baz=quuz", "application/x-www-form-urlencoded")
125138
# or
126139
form_data = FormData(b"foo=bar\\r\\nbaz=qux\\r\\nbaz=quux", "text/plain")
127-
# FormData({"foo": "bar", "baz": "qux"})
140+
# FormData({"foo": ["bar"], "baz": ["qux", "quux"]})
128141
129142
form_data.get("foo") # "bar"
130143
form_data["foo"] # "bar"
@@ -135,10 +148,12 @@ class FormData(_IFieldStorage):
135148
"""
136149

137150
_storage: Dict[str, List[Union[str, bytes]]]
151+
files: Files
138152

139153
def __init__(self, data: bytes, content_type: str) -> None:
140154
self.content_type = content_type
141155
self._storage = {}
156+
self.files = Files()
142157

143158
if content_type.startswith("application/x-www-form-urlencoded"):
144159
self._parse_x_www_form_urlencoded(data)
@@ -162,11 +177,25 @@ def _parse_multipart_form_data(self, data: bytes, boundary: str) -> None:
162177
blocks = data.split(b"--" + boundary.encode())[1:-1]
163178

164179
for block in blocks:
165-
disposition, content = block.split(b"\r\n\r\n", 1)
166-
field_name = disposition.split(b'"', 2)[1].decode()
167-
value = content[:-2]
180+
header_bytes, content_bytes = block.split(b"\r\n\r\n", 1)
181+
headers = Headers(header_bytes.decode("utf-8").strip())
168182

169-
self._add_field_value(field_name, value)
183+
field_name = headers.get_parameter("Content-Disposition", "name")
184+
filename = headers.get_parameter("Content-Disposition", "filename")
185+
content_type = headers.get_directive("Content-Type", "text/plain")
186+
charset = headers.get_parameter("Content-Type", "charset", "utf-8")
187+
188+
content = content_bytes[:-2] # remove trailing \r\n
189+
value = content.decode(charset) if content_type == "text/plain" else content
190+
191+
# TODO: Other text content types (e.g. application/json) should be decoded as well and
192+
193+
if filename is not None:
194+
self.files._add_field_value( # pylint: disable=protected-access
195+
field_name, File(filename, content_type, value)
196+
)
197+
else:
198+
self._add_field_value(field_name, value)
170199

171200
def _parse_text_plain(self, data: bytes) -> None:
172201
lines = data.decode("utf-8").split("\r\n")[:-1]
@@ -176,6 +205,21 @@ def _parse_text_plain(self, data: bytes) -> None:
176205

177206
self._add_field_value(field_name, value)
178207

208+
def _add_field_value(self, field_name: str, value: Union[str, bytes]) -> None:
209+
super()._add_field_value(field_name, value)
210+
211+
def get(
212+
self, field_name: str, default: Union[str, bytes] = None, *, safe=True
213+
) -> Union[str, bytes, None]:
214+
return super().get(field_name, default, safe=safe)
215+
216+
def get_list(self, field_name: str, *, safe=True) -> List[Union[str, bytes]]:
217+
return super().get_list(field_name, safe=safe)
218+
219+
def __repr__(self) -> str:
220+
class_name = self.__class__.__name__
221+
return f"{class_name}({repr(self._storage)}, files={repr(self.files._storage)})"
222+
179223

180224
class Request:
181225
"""
@@ -358,12 +402,3 @@ def _parse_request_header(
358402
headers = Headers(headers_string)
359403

360404
return method, path, query_params, http_version, headers
361-
362-
363-
def _debug_warning_nonencoded_output():
364-
"""Warns about XSS risks."""
365-
print(
366-
"WARNING: Setting safe to False makes XSS vulnerabilities possible by "
367-
"allowing access to raw untrusted values submitted by users. If this data is reflected "
368-
"or shown within HTML without proper encoding it could enable Cross-Site Scripting."
369-
)

docs/examples.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ It is important to use correct ``enctype``, depending on the type of data you wa
169169
- ``application/x-www-form-urlencoded`` - For sending simple text data without any special characters including spaces.
170170
If you use it, values will be automatically parsed as strings, but special characters will be URL encoded
171171
e.g. ``"Hello World! ^-$%"`` will be saved as ``"Hello+World%21+%5E-%24%25"``
172-
- ``multipart/form-data`` - For sending text and binary files and/or text data with special characters
173-
When used, values will **not** be automatically parsed as strings, they will stay as bytes instead.
174-
e.g. ``"Hello World! ^-$%"`` will be saved as ``b'Hello World! ^-$%'``, which can be decoded using ``.decode()`` method.
172+
- ``multipart/form-data`` - For sending textwith special characters and files
173+
When used, non-file values will be automatically parsed as strings and non plain text files will be saved as ``bytes``.
174+
e.g. ``"Hello World! ^-$%"`` will be saved as ``'Hello World! ^-$%'``, and e.g. a PNG file will be saved as ``b'\x89PNG\r\n\x1a\n\x00\...``.
175175
- ``text/plain`` - For sending text data with special characters.
176176
If used, values will be automatically parsed as strings, including special characters, emojis etc.
177177
e.g. ``"Hello World! ^-$%"`` will be saved as ``"Hello World! ^-$%"``, this is the **recommended** option.

0 commit comments

Comments
 (0)