Skip to content

NetworkClientSecure made copyable #9612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ NetworkClientSecure::NetworkClientSecure() {
_connected = false;
_timeout = 30000; // Same default as ssl_client

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;
});
ssl_init(sslclient.get());
sslclient->socket = -1;
sslclient->handshake_timeout = 120000;
_use_insecure = false;
Expand All @@ -53,8 +56,11 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
_lastReadTimeout = 0;
_lastWriteTimeout = 0;

sslclient = new sslclient_context;
ssl_init(sslclient);
sslclient.reset(new sslclient_context, [](struct sslclient_context *sslclient) {
stop_ssl_socket(sslclient);
delete sslclient;
});
ssl_init(sslclient.get());
sslclient->socket = sock;
sslclient->handshake_timeout = 120000;

Expand All @@ -71,20 +77,10 @@ NetworkClientSecure::NetworkClientSecure(int sock) {
_alpn_protos = NULL;
}

NetworkClientSecure::~NetworkClientSecure() {
stop();
delete sslclient;
}

NetworkClientSecure &NetworkClientSecure::operator=(const NetworkClientSecure &other) {
stop();
sslclient->socket = other.sslclient->socket;
_connected = other._connected;
return *this;
}
NetworkClientSecure::~NetworkClientSecure() {}

void NetworkClientSecure::stop() {
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
stop_ssl_socket(sslclient.get());

_connected = false;
_peek = -1;
Expand Down Expand Up @@ -130,10 +126,10 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *CA
}

int NetworkClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) {
int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);

if (ret >= 0 && !_stillinPlainStart) {
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
} else {
log_i("Actual TLS start postponed.");
}
Expand All @@ -153,7 +149,7 @@ int NetworkClientSecure::startTLS() {
int ret = 1;
if (_stillinPlainStart) {
log_i("startTLS: starting TLS/SSL on this dplain connection");
ret = ssl_starttls_handshake(sslclient);
ret = ssl_starttls_handshake(sslclient.get());
if (ret < 0) {
log_e("startTLS: %d", ret);
stop();
Expand All @@ -178,7 +174,7 @@ int NetworkClientSecure::connect(const char *host, uint16_t port, const char *ps
return 0;
}

int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
int ret = start_ssl_client(sslclient.get(), address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
_lastError = ret;
if (ret < 0) {
log_e("start_ssl_client: connect failed %d", ret);
Expand Down Expand Up @@ -213,7 +209,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
}

if (_stillinPlainStart) {
return send_net_data(sslclient, buf, size);
return send_net_data(sslclient.get(), buf, size);
}

if (_lastWriteTimeout != _timeout) {
Expand All @@ -224,7 +220,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {
_lastWriteTimeout = _timeout;
}
}
int res = send_ssl_data(sslclient, buf, size);
int res = send_ssl_data(sslclient.get(), buf, size);
if (res < 0) {
log_e("Closing connection on failed write");
stop();
Expand All @@ -235,7 +231,7 @@ size_t NetworkClientSecure::write(const uint8_t *buf, size_t size) {

int NetworkClientSecure::read(uint8_t *buf, size_t size) {
if (_stillinPlainStart) {
return get_net_receive(sslclient, buf, size);
return get_net_receive(sslclient.get(), buf, size);
}

if (_lastReadTimeout != _timeout) {
Expand Down Expand Up @@ -268,7 +264,7 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {
buf++;
peeked = 1;
}
res = get_ssl_receive(sslclient, buf, size);
res = get_ssl_receive(sslclient.get(), buf, size);

if (res < 0) {
log_e("Closing connection on failed read");
Expand All @@ -280,14 +276,14 @@ int NetworkClientSecure::read(uint8_t *buf, size_t size) {

int NetworkClientSecure::available() {
if (_stillinPlainStart) {
return peek_net_receive(sslclient, 0);
return peek_net_receive(sslclient.get(), 0);
}

int peeked = (_peek >= 0), res = -1;
if (!_connected) {
return peeked;
}
res = data_to_read(sslclient);
res = data_to_read(sslclient.get());

if (res < 0 && !_stillinPlainStart) {
log_e("Closing connection on failed available check");
Expand Down Expand Up @@ -346,7 +342,7 @@ bool NetworkClientSecure::verify(const char *fp, const char *domain_name) {
return false;
}

return verify_ssl_fingerprint(sslclient, fp, domain_name);
return verify_ssl_fingerprint(sslclient.get(), fp, domain_name);
}

char *NetworkClientSecure::_streamLoad(Stream &stream, size_t size) {
Expand Down
7 changes: 4 additions & 3 deletions libraries/NetworkClientSecure/src/NetworkClientSecure.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
#include "IPAddress.h"
#include "Network.h"
#include "ssl_client.h"
#include <memory>

class NetworkClientSecure : public NetworkClient {
protected:
sslclient_context *sslclient;
std::shared_ptr<sslclient_context> sslclient;

int _lastError = 0;
int _peek = -1;
Expand Down Expand Up @@ -97,14 +98,14 @@ class NetworkClientSecure : public NetworkClient {
return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx);
};
bool getFingerprintSHA256(uint8_t sha256_result[32]) {
return get_peer_fingerprint(sslclient, sha256_result);
return get_peer_fingerprint(sslclient.get(), sha256_result);
};
int fd() const;

operator bool() {
return connected();
}
NetworkClientSecure &operator=(const NetworkClientSecure &other);

bool operator==(const bool value) {
return bool() == value;
}
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ int ssl_starttls_handshake(sslclient_context *ssl_client) {
return ssl_client->socket;
}

void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) {
void stop_ssl_socket(sslclient_context *ssl_client) {
log_v("Cleaning SSL connection.");

if (ssl_client->socket >= 0) {
Expand Down
2 changes: 1 addition & 1 deletion libraries/NetworkClientSecure/src/ssl_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int start_ssl_client(
const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos
);
int ssl_starttls_handshake(sslclient_context *ssl_client);
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
void stop_ssl_socket(sslclient_context *ssl_client);
int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);
Expand Down