Skip to content

Commit c701ab7

Browse files
authored
Merge pull request #2 from FoamyGuy/files_arg_plus_foamyguy
Files arg plus foamyguy
2 parents ad7aaca + 913c4c8 commit c701ab7

File tree

4 files changed

+110
-74
lines changed

4 files changed

+110
-74
lines changed

adafruit_requests.py

Lines changed: 97 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
from adafruit_connection_manager import get_connection_manager
4848

49+
SEEK_END = 2
50+
4951
if not sys.implementation.name == "circuitpython":
5052
from types import TracebackType
5153
from typing import Any, Dict, Optional, Type
@@ -344,14 +346,6 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt
344346
self.close()
345347

346348

347-
def _generate_boundary_str():
348-
hex_characters = "0123456789abcdef"
349-
_boundary = ""
350-
for _ in range(32):
351-
_boundary += random.choice(hex_characters)
352-
return _boundary
353-
354-
355349
class Session:
356350
"""HTTP session that shares sockets and ssl context."""
357351

@@ -366,10 +360,74 @@ def __init__(
366360
self._session_id = session_id
367361
self._last_response = None
368362

363+
def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals
364+
boundary_string = self._build_boundary_string()
365+
content_length = 0
366+
boundary_objects = []
367+
368+
for field_name, field_values in files.items():
369+
file_name = field_values[0]
370+
file_handle = field_values[1]
371+
372+
boundary_data = f"--{boundary_string}\r\n"
373+
boundary_data += f'Content-Disposition: form-data; name="{field_name}"'
374+
if file_name is not None:
375+
boundary_data += f'; filename="{file_name}"'
376+
boundary_data += "\r\n"
377+
if len(field_values) >= 3:
378+
file_content_type = field_values[2]
379+
boundary_data += f"Content-Type: {file_content_type}\r\n"
380+
if len(field_values) >= 4:
381+
file_headers = field_values[3]
382+
for file_header_key, file_header_value in file_headers.items():
383+
boundary_data += f"{file_header_key}: {file_header_value}\r\n"
384+
boundary_data += "\r\n"
385+
386+
content_length += len(boundary_data)
387+
boundary_objects.append(boundary_data)
388+
389+
if hasattr(file_handle, "read"):
390+
is_binary = False
391+
try:
392+
content = file_handle.read(1)
393+
is_binary = isinstance(content, bytes)
394+
except UnicodeError:
395+
is_binary = False
396+
397+
if not is_binary:
398+
raise AttributeError("Files must be opened in binary mode")
399+
400+
file_handle.seek(0, SEEK_END)
401+
content_length += file_handle.tell()
402+
file_handle.seek(0)
403+
boundary_objects.append(file_handle)
404+
boundary_data = ""
405+
else:
406+
boundary_data = file_handle
407+
408+
boundary_data += "\r\n"
409+
content_length += len(boundary_data)
410+
boundary_objects.append(boundary_data)
411+
412+
boundary_data = f"--{boundary_string}--\r\n"
413+
414+
content_length += len(boundary_data)
415+
boundary_objects.append(boundary_data)
416+
417+
return boundary_string, content_length, boundary_objects
418+
419+
@staticmethod
420+
def _build_boundary_string():
421+
hex_characters = "0123456789abcdef"
422+
_boundary = ""
423+
for _ in range(32):
424+
_boundary += random.choice(hex_characters)
425+
return _boundary
426+
369427
@staticmethod
370428
def _check_headers(headers: Dict[str, str]):
371429
if not isinstance(headers, dict):
372-
raise AttributeError("headers must be in dict format")
430+
raise AttributeError("Headers must be in dict format")
373431

374432
for key, value in headers.items():
375433
if isinstance(value, (str, bytes)) or value is None:
@@ -403,6 +461,19 @@ def _send(socket: SocketType, data: bytes):
403461
def _send_as_bytes(self, socket: SocketType, data: str):
404462
return self._send(socket, bytes(data, "utf-8"))
405463

464+
def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any):
465+
for boundary_object in boundary_objects:
466+
if isinstance(boundary_object, str):
467+
self._send_as_bytes(socket, boundary_object)
468+
else:
469+
chunk_size = 32
470+
b = bytearray(chunk_size)
471+
while True:
472+
size = boundary_object.readinto(b)
473+
if size == 0:
474+
break
475+
self._send(socket, b[:size])
476+
406477
def _send_header(self, socket, header, value):
407478
if value is None:
408479
return
@@ -440,6 +511,7 @@ def _send_request( # pylint: disable=too-many-arguments
440511

441512
# If data is sent and it's a dict, set content type header and convert to string
442513
if data and isinstance(data, dict):
514+
assert files is None
443515
content_type_header = "application/x-www-form-urlencoded"
444516
_post_data = ""
445517
for k in data:
@@ -451,8 +523,18 @@ def _send_request( # pylint: disable=too-many-arguments
451523
if data and isinstance(data, str):
452524
data = bytes(data, "utf-8")
453525

454-
if data is None:
455-
data = b""
526+
# If files are send, build data to send and calculate length
527+
content_length = 0
528+
boundary_objects = None
529+
if files and isinstance(files, dict):
530+
boundary_string, content_length, boundary_objects = (
531+
self._build_boundary_data(files)
532+
)
533+
content_type_header = f"multipart/form-data; boundary={boundary_string}"
534+
else:
535+
if data is None:
536+
data = b""
537+
content_length = len(data)
456538

457539
self._send_as_bytes(socket, method)
458540
self._send(socket, b" /")
@@ -461,60 +543,6 @@ def _send_request( # pylint: disable=too-many-arguments
461543

462544
# create lower-case supplied header list
463545
supplied_headers = {header.lower() for header in headers}
464-
boundary_str = None
465-
466-
# pylint: disable=too-many-nested-blocks
467-
if files is not None and isinstance(files, dict):
468-
boundary_str = _generate_boundary_str()
469-
content_type_header = f"multipart/form-data; boundary={boundary_str}"
470-
471-
for fieldname in files.keys():
472-
if not fieldname.endswith("-name"):
473-
if files[fieldname][0] is not None:
474-
file_content = files[fieldname][1].read()
475-
476-
data += b"--" + boundary_str.encode() + b"\r\n"
477-
data += (
478-
b'Content-Disposition: form-data; name="'
479-
+ fieldname.encode()
480-
+ b'"; filename="'
481-
+ files[fieldname][0].encode()
482-
+ b'"\r\n'
483-
)
484-
if len(files[fieldname]) >= 3:
485-
data += (
486-
b"Content-Type: "
487-
+ files[fieldname][2].encode()
488-
+ b"\r\n"
489-
)
490-
if len(files[fieldname]) >= 4:
491-
for custom_header_key in files[fieldname][3].keys():
492-
data += (
493-
custom_header_key.encode()
494-
+ b": "
495-
+ files[fieldname][3][custom_header_key].encode()
496-
+ b"\r\n"
497-
)
498-
data += b"\r\n"
499-
data += file_content + b"\r\n"
500-
else:
501-
# filename is None
502-
data += b"--" + boundary_str.encode() + b"\r\n"
503-
data += (
504-
b'Content-Disposition: form-data; name="'
505-
+ fieldname.encode()
506-
+ b'"; \r\n'
507-
)
508-
if len(files[fieldname]) >= 3:
509-
data += (
510-
b"Content-Type: "
511-
+ files[fieldname][2].encode()
512-
+ b"\r\n"
513-
)
514-
data += b"\r\n"
515-
data += files[fieldname][1].encode() + b"\r\n"
516-
517-
data += b"--" + boundary_str.encode() + b"--"
518546

519547
# Send headers
520548
if not "host" in supplied_headers:
@@ -523,8 +551,8 @@ def _send_request( # pylint: disable=too-many-arguments
523551
self._send_header(socket, "User-Agent", "Adafruit CircuitPython")
524552
if content_type_header and not "content-type" in supplied_headers:
525553
self._send_header(socket, "Content-Type", content_type_header)
526-
if data and not "content-length" in supplied_headers:
527-
self._send_header(socket, "Content-Length", str(len(data)))
554+
if (data or files) and not "content-length" in supplied_headers:
555+
self._send_header(socket, "Content-Length", str(content_length))
528556
# Iterate over keys to avoid tuple alloc
529557
for header in headers:
530558
self._send_header(socket, header, headers[header])
@@ -533,6 +561,8 @@ def _send_request( # pylint: disable=too-many-arguments
533561
# Send data
534562
if data:
535563
self._send(socket, bytes(data))
564+
elif boundary_objects:
565+
self._send_boundary_objects(socket, boundary_objects)
536566

537567
# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
538568
def request(

examples/wifi/expanded/requests_wifi_file_upload.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
ssl_context = adafruit_connection_manager.get_radio_ssl_context(wifi.radio)
1313
requests = adafruit_requests.Session(pool, ssl_context)
1414

15-
with open("raspi_snip.png", "rb") as file_handle:
15+
with open("requests_wifi_file_upload_image.png", "rb") as file_handle:
1616
files = {
1717
"file": (
18-
"raspi_snip.png",
18+
"requests_wifi_file_upload_image.png",
1919
file_handle,
2020
"image/png",
2121
{"CustomHeader": "BlinkaRocks"},
2222
),
2323
"othervalue": (None, "HelloWorld"),
2424
}
2525

26-
with requests.post(URL, files=files) as resp:
27-
print(resp.content)
26+
with requests.post(URL, files=files) as response:
27+
print(response.content)

tests/header_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def test_check_headers_not_dict(requests):
1212
with pytest.raises(AttributeError) as context:
1313
requests._check_headers("")
14-
assert "headers must be in dict format" in str(context)
14+
assert "Headers must be in dict format" in str(context)
1515

1616

1717
def test_check_headers_not_valid(requests):

tests/method_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def test_post_string(sock, requests):
5252

5353

5454
def test_post_form(sock, requests):
55-
data = {"Date": "July 25, 2019", "Time": "12:00"}
55+
data = {
56+
"Date": "July 25, 2019",
57+
"Time": "12:00",
58+
}
5659
requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=data)
5760
sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
5861
sock.send.assert_has_calls(
@@ -67,7 +70,10 @@ def test_post_form(sock, requests):
6770

6871

6972
def test_post_json(sock, requests):
70-
json_data = {"Date": "July 25, 2019", "Time": "12:00"}
73+
json_data = {
74+
"Date": "July 25, 2019",
75+
"Time": "12:00",
76+
}
7177
requests.post("http://" + mocket.MOCK_HOST_1 + "/post", json=json_data)
7278
sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
7379
sock.send.assert_has_calls(

0 commit comments

Comments
 (0)