20
20
import json
21
21
22
22
from .headers import Headers
23
+ from .interfaces import _IFieldStorage , _IXSSSafeFieldStorage
23
24
24
25
25
- class _IFieldStorage :
26
- """Interface with shared methods for QueryParams and FormData."""
27
-
28
- _storage : Dict [str , List [Union [str , bytes ]]]
29
-
30
- def _add_field_value (self , field_name : str , value : Union [str , bytes ]) -> None :
31
- if field_name not in self ._storage :
32
- self ._storage [field_name ] = [value ]
33
- else :
34
- self ._storage [field_name ].append (value )
35
-
36
- @staticmethod
37
- def _encode_html_entities (value : str ) -> str :
38
- """Encodes unsafe HTML characters."""
39
- return (
40
- str (value )
41
- .replace ("&" , "&" )
42
- .replace ("<" , "<" )
43
- .replace (">" , ">" )
44
- .replace ('"' , """ )
45
- .replace ("'" , "'" )
46
- )
47
-
48
- def get (
49
- self , field_name : str , default : Any = None , * , safe = True
50
- ) -> Union [str , bytes , None ]:
51
- """Get the value of a field."""
52
- if safe :
53
- return self ._encode_html_entities (
54
- self ._storage .get (field_name , [default ])[0 ]
55
- )
56
-
57
- _debug_warning_nonencoded_output ()
58
- return self ._storage .get (field_name , [default ])[0 ]
59
-
60
- def get_list (self , field_name : str ) -> List [Union [str , bytes ]]:
61
- """Get the list of values of a field."""
62
- return self ._storage .get (field_name , [])
63
-
64
- @property
65
- def fields (self ):
66
- """Returns a list of field names."""
67
- return list (self ._storage .keys ())
68
-
69
- def __getitem__ (self , field_name : str ):
70
- return self .get (field_name )
71
-
72
- def __iter__ (self ):
73
- return iter (self ._storage )
74
-
75
- def __len__ (self ):
76
- return len (self ._storage )
77
-
78
- def __contains__ (self , key : str ):
79
- return key in self ._storage
80
-
81
- def __repr__ (self ) -> str :
82
- return f"{ self .__class__ .__name__ } ({ repr (self ._storage )} )"
83
-
84
-
85
- class QueryParams (_IFieldStorage ):
26
+ class QueryParams (_IXSSSafeFieldStorage ):
86
27
"""
87
28
Class for parsing and storing GET query parameters requests.
88
29
89
30
Examples::
90
31
91
32
query_params = QueryParams("foo=bar&baz=qux&baz=quux")
92
- # QueryParams({"foo": "bar", "baz": ["qux", "quux"]})
33
+ # QueryParams({"foo": [ "bar"] , "baz": ["qux", "quux"]})
93
34
94
35
query_params.get("foo") # "bar"
95
36
query_params["foo"] # "bar"
@@ -111,8 +52,80 @@ def __init__(self, query_string: str) -> None:
111
52
elif query_param :
112
53
self ._add_field_value (query_param , "" )
113
54
55
+ def _add_field_value (self , field_name : str , value : str ) -> None :
56
+ super ()._add_field_value (field_name , value )
57
+
58
+ def get (
59
+ self , field_name : str , default : str = None , * , safe = True
60
+ ) -> Union [str , None ]:
61
+ return super ().get (field_name , default , safe = safe )
62
+
63
+ def get_list (self , field_name : str , * , safe = True ) -> List [str ]:
64
+ return super ().get_list (field_name , safe = safe )
65
+
114
66
115
- class FormData (_IFieldStorage ):
67
+ class File :
68
+ """
69
+ Class representing a file uploaded via POST.
70
+
71
+ Examples::
72
+
73
+ file = request.form_data.files.get("uploaded_file")
74
+ # File(filename="foo.txt", content_type="text/plain", size=14)
75
+
76
+ file.content
77
+ # "Hello, world!\\ n"
78
+ """
79
+
80
+ filename : str
81
+ """Filename of the file."""
82
+
83
+ content_type : str
84
+ """Content type of the file."""
85
+
86
+ content : Union [str , bytes ]
87
+ """Content of the file."""
88
+
89
+ def __init__ (
90
+ self , filename : str , content_type : str , content : Union [str , bytes ]
91
+ ) -> None :
92
+ self .filename = filename
93
+ self .content_type = content_type
94
+ self .content = content
95
+
96
+ @property
97
+ def size (self ) -> int :
98
+ """Length of the file content."""
99
+ return len (self .content )
100
+
101
+ def __repr__ (self ) -> str :
102
+ filename , content_type , size = (
103
+ repr (self .filename ),
104
+ repr (self .content_type ),
105
+ repr (self .size ),
106
+ )
107
+ return f"{ self .__class__ .__name__ } ({ filename = } , { content_type = } , { size = } )"
108
+
109
+
110
+ class Files (_IFieldStorage ):
111
+ """Class for files uploaded via POST."""
112
+
113
+ _storage : Dict [str , List [File ]]
114
+
115
+ def __init__ (self ) -> None :
116
+ self ._storage = {}
117
+
118
+ def _add_field_value (self , field_name : str , value : File ) -> None :
119
+ super ()._add_field_value (field_name , value )
120
+
121
+ def get (self , field_name : str , default : Any = None ) -> Union [File , Any , None ]:
122
+ return super ().get (field_name , default )
123
+
124
+ def get_list (self , field_name : str ) -> List [File ]:
125
+ return super ().get_list (field_name )
126
+
127
+
128
+ class FormData (_IXSSSafeFieldStorage ):
116
129
"""
117
130
Class for parsing and storing form data from POST requests.
118
131
@@ -124,7 +137,7 @@ class FormData(_IFieldStorage):
124
137
form_data = FormData(b"foo=bar&baz=qux&baz=quuz", "application/x-www-form-urlencoded")
125
138
# or
126
139
form_data = FormData(b"foo=bar\\ r\\ nbaz=qux\\ r\\ nbaz=quux", "text/plain")
127
- # FormData({"foo": "bar", "baz": "qux"})
140
+ # FormData({"foo": [ "bar"] , "baz": [ "qux", "quux"] })
128
141
129
142
form_data.get("foo") # "bar"
130
143
form_data["foo"] # "bar"
@@ -135,10 +148,12 @@ class FormData(_IFieldStorage):
135
148
"""
136
149
137
150
_storage : Dict [str , List [Union [str , bytes ]]]
151
+ files : Files
138
152
139
153
def __init__ (self , data : bytes , content_type : str ) -> None :
140
154
self .content_type = content_type
141
155
self ._storage = {}
156
+ self .files = Files ()
142
157
143
158
if content_type .startswith ("application/x-www-form-urlencoded" ):
144
159
self ._parse_x_www_form_urlencoded (data )
@@ -162,11 +177,25 @@ def _parse_multipart_form_data(self, data: bytes, boundary: str) -> None:
162
177
blocks = data .split (b"--" + boundary .encode ())[1 :- 1 ]
163
178
164
179
for block in blocks :
165
- disposition , content = block .split (b"\r \n \r \n " , 1 )
166
- field_name = disposition .split (b'"' , 2 )[1 ].decode ()
167
- value = content [:- 2 ]
180
+ header_bytes , content_bytes = block .split (b"\r \n \r \n " , 1 )
181
+ headers = Headers (header_bytes .decode ("utf-8" ).strip ())
168
182
169
- self ._add_field_value (field_name , value )
183
+ field_name = headers .get_parameter ("Content-Disposition" , "name" )
184
+ filename = headers .get_parameter ("Content-Disposition" , "filename" )
185
+ content_type = headers .get_directive ("Content-Type" , "text/plain" )
186
+ charset = headers .get_parameter ("Content-Type" , "charset" , "utf-8" )
187
+
188
+ content = content_bytes [:- 2 ] # remove trailing \r\n
189
+ value = content .decode (charset ) if content_type == "text/plain" else content
190
+
191
+ # TODO: Other text content types (e.g. application/json) should be decoded as well and
192
+
193
+ if filename is not None :
194
+ self .files ._add_field_value ( # pylint: disable=protected-access
195
+ field_name , File (filename , content_type , value )
196
+ )
197
+ else :
198
+ self ._add_field_value (field_name , value )
170
199
171
200
def _parse_text_plain (self , data : bytes ) -> None :
172
201
lines = data .decode ("utf-8" ).split ("\r \n " )[:- 1 ]
@@ -176,6 +205,21 @@ def _parse_text_plain(self, data: bytes) -> None:
176
205
177
206
self ._add_field_value (field_name , value )
178
207
208
+ def _add_field_value (self , field_name : str , value : Union [str , bytes ]) -> None :
209
+ super ()._add_field_value (field_name , value )
210
+
211
+ def get (
212
+ self , field_name : str , default : Union [str , bytes ] = None , * , safe = True
213
+ ) -> Union [str , bytes , None ]:
214
+ return super ().get (field_name , default , safe = safe )
215
+
216
+ def get_list (self , field_name : str , * , safe = True ) -> List [Union [str , bytes ]]:
217
+ return super ().get_list (field_name , safe = safe )
218
+
219
+ def __repr__ (self ) -> str :
220
+ class_name = self .__class__ .__name__
221
+ return f"{ class_name } ({ repr (self ._storage )} , files={ repr (self .files ._storage )} )"
222
+
179
223
180
224
class Request :
181
225
"""
@@ -358,12 +402,3 @@ def _parse_request_header(
358
402
headers = Headers (headers_string )
359
403
360
404
return method , path , query_params , http_version , headers
361
-
362
-
363
- def _debug_warning_nonencoded_output ():
364
- """Warns about XSS risks."""
365
- print (
366
- "WARNING: Setting safe to False makes XSS vulnerabilities possible by "
367
- "allowing access to raw untrusted values submitted by users. If this data is reflected "
368
- "or shown within HTML without proper encoding it could enable Cross-Site Scripting."
369
- )
0 commit comments