Skip to content

Commit 6782090

Browse files
committed
NetworkClientSecure made copyable
1 parent e8e251a commit 6782090

File tree

4 files changed

+30
-30
lines changed

4 files changed

+30
-30
lines changed

Diff for: libraries/NetworkClientSecure/src/NetworkClientSecure.cpp

+24-25
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ NetworkClientSecure::NetworkClientSecure() {
3232
_connected = false;
3333
_timeout = 30000; // Same default as ssl_client
3434

35-
sslclient = new sslclient_context;
36-
ssl_init(sslclient);
35+
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
36+
stop_ssl_socket(sslclient);
37+
delete sslclient;
38+
39+
});
40+
ssl_init(sslclient.get());
3741
sslclient->socket = -1;
3842
sslclient->handshake_timeout = 120000;
3943
_use_insecure = false;
@@ -53,8 +57,12 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
5357
_lastReadTimeout = 0;
5458
_lastWriteTimeout = 0;
5559

56-
sslclient = new sslclient_context;
57-
ssl_init(sslclient);
60+
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
61+
stop_ssl_socket(sslclient);
62+
delete sslclient;
63+
64+
});
65+
ssl_init(sslclient.get());
5866
sslclient->socket = sock;
5967
sslclient->handshake_timeout = 120000;
6068

@@ -72,19 +80,10 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
7280
}
7381

7482
NetworkClientSecure::~NetworkClientSecure() {
75-
stop();
76-
delete sslclient;
77-
}
78-
79-
NetworkClientSecure &NetworkClientSecure::operator=(const NetworkClientSecure &other) {
80-
stop();
81-
sslclient->socket = other.sslclient->socket;
82-
_connected = other._connected;
83-
return *this;
8483
}
8584

8685
void NetworkClientSecure::stop() {
87-
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
86+
stop_ssl_socket(sslclient.get());
8887

8988
_connected = false;
9089
_peek = -1;
@@ -130,10 +129,10 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *CA
130129
}
131130

132131
int NetworkClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) {
133-
int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
132+
int ret = start_ssl_client(sslclient.get(), ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
134133

135134
if (ret >= 0 && !_stillinPlainStart) {
136-
ret = ssl_starttls_handshake(sslclient);
135+
ret = ssl_starttls_handshake(sslclient.get());
137136
} else {
138137
log_i("Actual TLS start postponed.");
139138
}
@@ -153,7 +152,7 @@ int NetworkClientSecure::startTLS() {
153152
int ret = 1;
154153
if (_stillinPlainStart) {
155154
log_i("startTLS: starting TLS/SSL on this dplain connection");
156-
ret = ssl_starttls_handshake(sslclient);
155+
ret = ssl_starttls_handshake(sslclient.get());
157156
if (ret < 0) {
158157
log_e("startTLS: %d", ret);
159158
stop();
@@ -178,7 +177,7 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *ps
178177
return 0;
179178
}
180179

181-
int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
180+
int ret = start_ssl_client(sslclient.get(), address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
182181
_lastError = ret;
183182
if (ret < 0) {
184183
log_e("start_ssl_client: connect failed %d", ret);
@@ -213,7 +212,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
213212
}
214213

215214
if (_stillinPlainStart) {
216-
return send_net_data(sslclient, buf, size);
215+
return send_net_data(sslclient.get(), buf, size);
217216
}
218217

219218
if (_lastWriteTimeout != _timeout) {
@@ -224,7 +223,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
224223
_lastWriteTimeout = _timeout;
225224
}
226225
}
227-
int res = send_ssl_data(sslclient, buf, size);
226+
int res = send_ssl_data(sslclient.get(), buf, size);
228227
if (res < 0) {
229228
log_e("Closing connection on failed write");
230229
stop();
@@ -235,7 +234,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
235234

236235
int NetworkClientSecure::read(uint8_t *buf, size_t size) {
237236
if (_stillinPlainStart) {
238-
return get_net_receive(sslclient, buf, size);
237+
return get_net_receive(sslclient.get(), buf, size);
239238
}
240239

241240
if (_lastReadTimeout != _timeout) {
@@ -268,7 +267,7 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {
268267
buf++;
269268
peeked = 1;
270269
}
271-
res = get_ssl_receive(sslclient, buf, size);
270+
res = get_ssl_receive(sslclient.get(), buf, size);
272271

273272
if (res < 0) {
274273
log_e("Closing connection on failed read");
@@ -280,14 +279,14 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {
280279

281280
int NetworkClientSecure::available() {
282281
if (_stillinPlainStart) {
283-
return peek_net_receive(sslclient, 0);
282+
return peek_net_receive(sslclient.get(), 0);
284283
}
285284

286285
int peeked = (_peek >= 0), res = -1;
287286
if (!_connected) {
288287
return peeked;
289288
}
290-
res = data_to_read(sslclient);
289+
res = data_to_read(sslclient.get());
291290

292291
if (res < 0 && !_stillinPlainStart) {
293292
log_e("Closing connection on failed available check");
@@ -346,7 +345,7 @@ bool NetworkClientSecure::verify(const char *fp, const char *domain_name) {
346345
return false;
347346
}
348347

349-
return verify_ssl_fingerprint(sslclient, fp, domain_name);
348+
return verify_ssl_fingerprint(sslclient.get(), fp, domain_name);
350349
}
351350

352351
char *NetworkClientSecure::_streamLoad(Stream &stream, size_t size) {

Diff for: libraries/NetworkClientSecure/src/NetworkClientSecure.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
#include "IPAddress.h"
2525
#include "Network.h"
2626
#include "ssl_client.h"
27+
#include <memory>
2728

2829
class NetworkClientSecure : public NetworkClient {
2930
protected:
30-
sslclient_context *sslclient;
31+
std::shared_ptr<sslclient_context> sslclient;
3132

3233
int _lastError = 0;
3334
int _peek = -1;
@@ -97,14 +98,14 @@ class NetworkClientSecure : public NetworkClient {
9798
return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx);
9899
};
99100
bool getFingerprintSHA256(uint8_t sha256_result[32]) {
100-
return get_peer_fingerprint(sslclient, sha256_result);
101+
return get_peer_fingerprint(sslclient.get(), sha256_result);
101102
};
102103
int fd() const;
103104

104105
operator bool() {
105106
return connected();
106107
}
107-
NetworkClientSecure &operator=(const NetworkClientSecure &other);
108+
108109
bool operator==(const bool value) {
109110
return bool() == value;
110111
}

Diff for: libraries/NetworkClientSecure/src/ssl_client.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ int ssl_starttls_handshake(sslclient_context *ssl_client) {
344344
return ssl_client->socket;
345345
}
346346

347-
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) {
347+
void stop_ssl_socket(sslclient_context *ssl_client) {
348348
log_v("Cleaning SSL connection.");
349349

350350
if (ssl_client->socket >= 0) {

Diff for: libraries/NetworkClientSecure/src/ssl_client.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ int start_ssl_client(
3434
const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos
3535
);
3636
int ssl_starttls_handshake(sslclient_context *ssl_client);
37-
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
37+
void stop_ssl_socket(sslclient_context *ssl_client);
3838
int data_to_read(sslclient_context *ssl_client);
3939
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
4040
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);

0 commit comments

Comments
 (0)