Skip to content

Commit 6031bec

Browse files
committed
Added Headers.get_directive() and Headers.get_parameter()
@to get_parameter
1 parent d743d1c commit 6031bec

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

adafruit_httpserver/headers.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
try:
11-
from typing import Dict, Tuple
11+
from typing import Dict, Tuple, Union
1212
except ImportError:
1313
pass
1414

@@ -56,6 +56,43 @@ def get(self, name: str, default: str = None) -> Union[str, None]:
5656
"""Returns the value for the given header name, or default if not found."""
5757
return self._storage.get(name.lower(), [None, default])[1]
5858

59+
def get_directive(self, name: str, default: str = None) -> Union[str, None]:
60+
"""
61+
Returns the main value (directive) for the given header name, or default if not found.
62+
63+
Example::
64+
65+
headers = Headers({"Content-Type": "text/html; charset=utf-8"})
66+
headers.get_directive("Content-Type")
67+
# 'text/html'
68+
"""
69+
70+
header_value = self.get(name)
71+
if header_value is None:
72+
return default
73+
return header_value.split(";")[0].strip('" ')
74+
75+
def get_parameter(
76+
self, name: str, parameter: str, default: str = None
77+
) -> Union[str, None]:
78+
"""
79+
Returns the value of the given parameter for the given header name, or default if not found.
80+
81+
Example::
82+
83+
headers = Headers({"Content-Type": "text/html; charset=utf-8"})
84+
headers.get_parameter("Content-Type", "charset")
85+
# 'utf-8'
86+
"""
87+
88+
header_value = self.get(name)
89+
if header_value is None:
90+
return default
91+
for header_parameter in header_value.split(";"):
92+
if header_parameter.strip().startswith(parameter):
93+
return header_parameter.strip().split("=")[1].strip('" ')
94+
return default
95+
5996
def setdefault(self, name: str, default: str = None):
6097
"""Sets the value for the given header name if it does not exist."""
6198
return self._storage.setdefault(name.lower(), [name, default])[1]

adafruit_httpserver/response.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -581,18 +581,18 @@ def route_func(request: Request):
581581

582582
@staticmethod
583583
def _check_request_initiates_handshake(request: Request):
584-
if any(
584+
if not all(
585585
[
586-
"websocket" not in request.headers.get("Upgrade", "").lower(),
587-
"upgrade" not in request.headers.get("Connection", "").lower(),
588-
"Sec-WebSocket-Key" not in request.headers,
586+
"websocket" in request.headers.get_directive("Upgrade", "").lower(),
587+
"upgrade" in request.headers.get_directive("Connection", "").lower(),
588+
"Sec-WebSocket-Key" in request.headers,
589589
]
590590
):
591591
raise ValueError("Request does not initiate websocket handshake")
592592

593593
@staticmethod
594594
def _process_sec_websocket_key(request: Request) -> str:
595-
key = request.headers.get("Sec-WebSocket-Key")
595+
key = request.headers.get_directive("Sec-WebSocket-Key")
596596

597597
if key is None:
598598
raise ValueError("Request does not have Sec-WebSocket-Key header")

adafruit_httpserver/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _receive_request(
243243

244244
request = Request(self, sock, client_address, header_bytes)
245245

246-
content_length = int(request.headers.get("Content-Length", 0))
246+
content_length = int(request.headers.get_directive("Content-Length", 0))
247247
received_body_bytes = request.body
248248

249249
# Receiving remaining body bytes

0 commit comments

Comments
 (0)