Skip to content

Commit 49d1177

Browse files
committed
ucloud: Switch to SSLContext.
Signed-off-by: iabdalkader <[email protected]>
1 parent a8dd722 commit 49d1177

File tree

2 files changed

+67
-63
lines changed

2 files changed

+67
-63
lines changed

src/arduino_iot_cloud/umqtt.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import select
2828
import logging
2929
import arduino_iot_cloud.ussl as ssl
30+
import sys
3031

3132

3233
class MQTTException(Exception):
@@ -92,17 +93,14 @@ def connect(self, clean_session=True, timeout=5.0):
9293
self.sock.close()
9394
self.sock = None
9495

95-
try:
96-
self.sock = socket.socket()
97-
self.sock.settimeout(timeout)
98-
self.sock = ssl.wrap_socket(self.sock, self.ssl_params)
99-
self.sock.connect(addr)
100-
except Exception:
101-
self.sock.close()
102-
self.sock = socket.socket()
103-
self.sock.settimeout(timeout)
96+
self.sock = socket.socket()
97+
self.sock.settimeout(timeout)
98+
if sys.implementation.name == "micropython":
10499
self.sock.connect(addr)
105100
self.sock = ssl.wrap_socket(self.sock, self.ssl_params)
101+
else:
102+
self.sock = ssl.wrap_socket(self.sock, self.ssl_params)
103+
self.sock.connect(addr)
106104

107105
premsg = bytearray(b"\x10\0\0\0\0\0")
108106
msg = bytearray(b"\x04MQTT\x04\x02\0\0")

src/arduino_iot_cloud/ussl.py

+60-54
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#
77
# SSL module with m2crypto backend for HSM support.
88

9-
import sys
109
import ssl
1110

1211
pkcs11 = None
@@ -16,62 +15,69 @@
1615
_MODULE_PATH = "/usr/lib/softhsm/libsofthsm2.so"
1716

1817

19-
def wrap_socket(
20-
sock,
21-
ssl_params={},
22-
):
23-
if any(k not in ssl_params for k in ("keyfile", "certfile", "pin")):
24-
# Use Micro/CPython's SSL
25-
if sys.implementation.name == "micropython":
26-
# Load key, cert and CA from DER files, and pass them as binary blobs.
27-
mpargs = {"keyfile": "key", "certfile": "cert", "ca_certs": "cadata"}
28-
for k, v in mpargs.items():
29-
if k in ssl_params and "der" in ssl_params[k]:
30-
with open(ssl_params.pop(k), "rb") as f:
31-
ssl_params[v] = f.read()
32-
return ssl.wrap_socket(sock, **ssl_params)
33-
34-
# Use M2Crypto to load key and cert from HSM.
35-
from M2Crypto import m2, SSL, Engine
36-
37-
global pkcs11
38-
if pkcs11 is None:
39-
pkcs11 = Engine.load_dynamic_engine(
40-
"pkcs11", ssl_params.get("engine_path", _ENGINE_PATH)
41-
)
42-
pkcs11.ctrl_cmd_string(
43-
"MODULE_PATH", ssl_params.get("module_path", _MODULE_PATH)
44-
)
45-
pkcs11.ctrl_cmd_string("PIN", ssl_params["pin"])
46-
pkcs11.init()
47-
48-
# Create and configure SSL context
49-
ctx = SSL.Context("tls")
50-
ctx.set_default_verify_paths()
51-
ctx.set_allow_unknown_ca(False)
52-
18+
def wrap_socket(sock, ssl_params={}):
19+
keyfile = ssl_params.get("keyfile", None)
20+
certfile = ssl_params.get("certfile", None)
21+
cafile = ssl_params.get("cafile", None)
22+
cadata = ssl_params.get("cadata", None)
5323
ciphers = ssl_params.get("ciphers", None)
54-
if ciphers is not None:
55-
ctx.set_cipher_list(ciphers)
24+
verify = ssl_params.get("verify_mode", ssl.CERT_NONE)
25+
hostname = ssl_params.get("server_hostname", None)
26+
use_hsm = ssl_params.get("use_hsm", False)
5627

57-
ca_certs = ssl_params.get("ca_certs", None)
58-
if ca_certs is not None:
59-
if ctx.load_verify_locations(ca_certs) != 1:
60-
raise Exception("Failed to load CA certs")
61-
62-
cert_reqs = ssl_params.get("cert_reqs", ssl.CERT_NONE)
63-
if cert_reqs == ssl.CERT_NONE:
64-
cert_reqs = SSL.verify_none
28+
if not use_hsm:
29+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
30+
if hasattr(ctx, "set_default_verify_paths"):
31+
ctx.set_default_verify_paths()
32+
if verify != ssl.CERT_REQUIRED:
33+
ctx.check_hostname = False
34+
ctx.verify_mode = verify
35+
if keyfile is not None and certfile is not None:
36+
ctx.load_cert_chain(certfile, keyfile)
37+
if ciphers is not None:
38+
ctx.set_ciphers(ciphers)
39+
if cafile is not None or cadata is not None:
40+
ctx.load_verify_locations(cafile, cadata)
41+
return ctx.wrap_socket(sock, server_hostname=hostname)
6542
else:
66-
cert_reqs = SSL.verify_peer
67-
ctx.set_verify(cert_reqs, depth=9)
43+
# Use M2Crypto to load key and cert from HSM.
44+
from M2Crypto import m2, SSL, Engine
45+
46+
global pkcs11
47+
if pkcs11 is None:
48+
pkcs11 = Engine.load_dynamic_engine(
49+
"pkcs11", ssl_params.get("engine_path", _ENGINE_PATH)
50+
)
51+
pkcs11.ctrl_cmd_string(
52+
"MODULE_PATH", ssl_params.get("module_path", _MODULE_PATH)
53+
)
54+
if "pin" in ssl_params:
55+
pkcs11.ctrl_cmd_string("PIN", ssl_params["pin"])
56+
pkcs11.init()
57+
58+
# Create and configure SSL context
59+
ctx = SSL.Context("tls")
60+
ctx.set_default_verify_paths()
61+
ctx.set_allow_unknown_ca(False)
62+
if verify == ssl.CERT_NONE:
63+
ctx.set_verify(SSL.verify_none, depth=9)
64+
else:
65+
ctx.set_verify(SSL.verify_peer | SSL.verify_fail_if_no_peer_cert, depth=9)
66+
if cafile is not None:
67+
if ctx.load_verify_locations(cafile) != 1:
68+
raise Exception("Failed to load CA certs")
69+
if ciphers is not None:
70+
ctx.set_cipher_list(ciphers)
6871

69-
# Set key/cert
70-
key = pkcs11.load_private_key(ssl_params["keyfile"])
71-
m2.ssl_ctx_use_pkey_privkey(ctx.ctx, key.pkey)
72+
key = pkcs11.load_private_key(keyfile)
73+
m2.ssl_ctx_use_pkey_privkey(ctx.ctx, key.pkey)
7274

73-
cert = pkcs11.load_certificate(ssl_params["certfile"])
74-
m2.ssl_ctx_use_x509(ctx.ctx, cert.x509)
75+
cert = pkcs11.load_certificate(certfile)
76+
m2.ssl_ctx_use_x509(ctx.ctx, cert.x509)
7577

76-
SSL.Connection.postConnectionCheck = None
77-
return SSL.Connection(ctx, sock=sock)
78+
sslobj = SSL.Connection(ctx, sock=sock)
79+
if verify == ssl.CERT_NONE:
80+
sslobj.clientPostConnectionCheck = None
81+
elif hostname is not None:
82+
sslobj.set1_host(hostname)
83+
return sslobj

0 commit comments

Comments
 (0)