Skip to content

Fix WiFiClientSecure memory leaks when the connection fails (certificate verification, handshake, etc.) #5944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 60 additions & 119 deletions libraries/WiFiClientSecure/src/ssl_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,20 @@
#include "ssl_client.h"
#include "WiFi.h"

#ifndef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED
# warning "Please configure IDF framework to include mbedTLS -> Enable pre-shared-key ciphersuites and activate at least one cipher"
#else

const char *pers = "esp32-tls";

static int _handle_error(int err, const char * function, int line)
static int _handle_error(int err, const char * file, int line)
{
if(err == -30848){
return err;
}
#ifdef MBEDTLS_ERROR_C
char error_buf[100];
mbedtls_strerror(err, error_buf, 100);
log_e("[%s():%d]: (%d) %s", function, line, err, error_buf);
log_e("[%s():%d]: (%d) %s", file, line, err, error_buf);
#else
log_e("[%s():%d]: code %d", function, line, err);
log_e("[%s():%d]: code %d", file, line, err);
#endif
return err;
}
Expand All @@ -45,23 +42,21 @@ static int _handle_error(int err, const char * function, int line)

void ssl_init(sslclient_context *ssl_client)
{
// reset embedded pointers to zero
memset(ssl_client, 0, sizeof(sslclient_context));
mbedtls_ssl_init(&ssl_client->ssl_ctx);
mbedtls_ssl_config_init(&ssl_client->ssl_conf);
mbedtls_ctr_drbg_init(&ssl_client->drbg_ctx);
}


int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos)
int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey)
{
char buf[512];
int ret, flags;
int enable = 1;
log_v("Free internal heap before TLS %u", ESP.getFreeHeap());

if (rootCABuff == NULL && pskIdent == NULL && psKey == NULL && !insecure) {
return -1;
}

log_v("Starting socket");
ssl_client->socket = -1;

Expand All @@ -76,67 +71,26 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
return -1;
}

fcntl( ssl_client->socket, F_SETFL, fcntl( ssl_client->socket, F_GETFL, 0 ) | O_NONBLOCK );
struct sockaddr_in serv_addr;
memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = srv;
serv_addr.sin_port = htons(port);

if(timeout <= 0){
timeout = 30000; // Milli seconds.
}

fd_set fdset;
struct timeval tv;
FD_ZERO(&fdset);
FD_SET(ssl_client->socket, &fdset);
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;

int res = lwip_connect(ssl_client->socket, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
if (res < 0 && errno != EINPROGRESS) {
log_e("connect on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno));
close(ssl_client->socket);
return -1;
}

res = select(ssl_client->socket + 1, nullptr, &fdset, nullptr, timeout<0 ? nullptr : &tv);
if (res < 0) {
log_e("select on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno));
close(ssl_client->socket);
return -1;
} else if (res == 0) {
log_i("select returned due to timeout %d ms for fd %d", timeout, ssl_client->socket);
close(ssl_client->socket);
return -1;
} else {
int sockerr;
socklen_t len = (socklen_t)sizeof(int);
res = getsockopt(ssl_client->socket, SOL_SOCKET, SO_ERROR, &sockerr, &len);

if (res < 0) {
log_e("getsockopt on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno));
close(ssl_client->socket);
return -1;
}

if (sockerr != 0) {
log_e("socket error on fd %d, errno: %d, \"%s\"", ssl_client->socket, sockerr, strerror(sockerr));
close(ssl_client->socket);
return -1;
if (lwip_connect(ssl_client->socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) {
if(timeout <= 0){
timeout = 30000;
}
lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));
lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));
lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable));
} else {
log_e("Connect to Server failed!");
return -1;
}


#define ROE(x,msg) { if (((x)<0)) { log_e("LWIP Socket config of " msg " failed."); return -1; }}
ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),"SO_RCVTIMEO");
ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),"SO_SNDTIMEO");

ROE(lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable)),"TCP_NODELAY");
ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)),"SO_KEEPALIVE");


fcntl( ssl_client->socket, F_SETFL, fcntl( ssl_client->socket, F_GETFL, 0 ) | O_NONBLOCK );

log_v("Seeding the random number generator");
mbedtls_entropy_init(&ssl_client->entropy_ctx);
Expand All @@ -156,29 +110,17 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
return handle_error(ret);
}

if (alpn_protos != NULL) {
log_v("Setting ALPN protocols");
if ((ret = mbedtls_ssl_conf_alpn_protocols(&ssl_client->ssl_conf, alpn_protos) ) != 0) {
return handle_error(ret);
}
}

// MBEDTLS_SSL_VERIFY_REQUIRED if a CA certificate is defined on Arduino IDE and
// MBEDTLS_SSL_VERIFY_NONE if not.

if (insecure) {
mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE);
log_i("WARNING: Skipping SSL Verification. INSECURE!");
} else if (rootCABuff != NULL) {
if (rootCABuff != NULL) {
log_v("Loading CA cert");
mbedtls_x509_crt_init(&ssl_client->ca_cert);
mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED);
ret = mbedtls_x509_crt_parse(&ssl_client->ca_cert, (const unsigned char *)rootCABuff, strlen(rootCABuff) + 1);
mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL);
//mbedtls_ssl_conf_verify(&ssl_client->ssl_ctx, my_verify, NULL );
if (ret < 0) {
// free the ca_cert in the case parse failed, otherwise, the old ca_cert still in the heap memory, that lead to "out of memory" crash.
mbedtls_x509_crt_free(&ssl_client->ca_cert);
return handle_error(ret);
}
} else if (pskIdent != NULL && psKey != NULL) {
Expand Down Expand Up @@ -212,26 +154,26 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
return handle_error(ret);
}
} else {
return -1;
mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE);
log_i("WARNING: Use certificates for a more secure communication!");
}

if (!insecure && cli_cert != NULL && cli_key != NULL) {
if (cli_cert != NULL && cli_key != NULL) {
mbedtls_x509_crt_init(&ssl_client->client_cert);
mbedtls_pk_init(&ssl_client->client_key);

log_v("Loading CRT cert");

ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1);
if (ret < 0) {
// free the client_cert in the case parse failed, otherwise, the old client_cert still in the heap memory, that lead to "out of memory" crash.
mbedtls_x509_crt_free(&ssl_client->client_cert);
return handle_error(ret);
}

log_v("Loading private key");
ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0);

if (ret != 0) {
mbedtls_x509_crt_free(&ssl_client->client_cert); // cert+key are free'd in pair
return handle_error(ret);
}

Expand All @@ -243,7 +185,7 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
// Hostname set here should match CN in server certificate
if((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host)) != 0){
return handle_error(ret);
}
}

mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx);

Expand All @@ -257,11 +199,12 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
unsigned long handshake_start_time=millis();
while ((ret = mbedtls_ssl_handshake(&ssl_client->ssl_ctx)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
log_d("Verify result %04X", mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx));
return handle_error(ret);
}
if((millis()-handshake_start_time)>ssl_client->handshake_timeout)
return -1;
vTaskDelay(2);//2 ticks
return -1;
vTaskDelay(10 / portTICK_PERIOD_MS);
}


Expand All @@ -277,10 +220,10 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
log_v("Verifying peer X.509 certificate...");

if ((flags = mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx)) != 0) {
memset(buf, 0, sizeof(buf));
bzero(buf, sizeof(buf));
mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", flags);
log_e("Failed to verify peer certificate! verification info: %s", buf);
stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); //It's not safe continue.
//It's not safe continue.
return handle_error(ret);
} else {
log_v("Certificate verified.");
Expand Down Expand Up @@ -312,11 +255,20 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons
close(ssl_client->socket);
ssl_client->socket = -1;
}

// avoid memory leak if ssl connection attempt failed
if (ssl_client->ssl_conf.ca_chain != NULL) {
mbedtls_x509_crt_free(&ssl_client->ca_cert);
}
if (ssl_client->ssl_conf.key_cert != NULL) {
mbedtls_x509_crt_free(&ssl_client->client_cert);
mbedtls_pk_free(&ssl_client->client_key);
}
mbedtls_ssl_free(&ssl_client->ssl_ctx);
mbedtls_ssl_config_free(&ssl_client->ssl_conf);
mbedtls_ctr_drbg_free(&ssl_client->drbg_ctx);
mbedtls_entropy_free(&ssl_client->entropy_ctx);
// reset embedded pointers to zero
memset(ssl_client, 0, sizeof(sslclient_context));
}


Expand All @@ -334,26 +286,27 @@ int data_to_read(sslclient_context *ssl_client)
return res;
}

int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len)

int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len)
{
log_v("Writing HTTP request with %d bytes...", len); //for low level debug
log_v("Writing SSL data..."); //for low level debug
int ret = -1;

while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
log_v("Handling error %d", ret); //for low level debug
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
return handle_error(ret);
}
//wait for space to become available
vTaskDelay(2);
}

len = ret;
//log_v("%d bytes written", len); //for low level debug
return ret;
}


int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length)
{
//log_d( "Reading HTTP response..."); //for low level debug
log_v( "Reading SSL data..."); //for low level debug
int ret = -1;

ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length);
Expand Down Expand Up @@ -425,10 +378,22 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const
fingerprint_local[i] = low | (high << 4);
}

// Get certificate provided by the peer
const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx);

if (!crt)
{
log_d("could not fetch peer certificate");
return false;
}

// Calculate certificate's SHA256 fingerprint
uint8_t fingerprint_remote[32];
if(!get_peer_fingerprint(ssl_client, fingerprint_remote))
return false;
mbedtls_sha256_context sha256_ctx;
mbedtls_sha256_init(&sha256_ctx);
mbedtls_sha256_starts(&sha256_ctx, false);
mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len);
mbedtls_sha256_finish(&sha256_ctx, fingerprint_remote);

// Check if fingerprints match
if (memcmp(fingerprint_local, fingerprint_remote, 32))
Expand All @@ -444,28 +409,6 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const
return true;
}

bool get_peer_fingerprint(sslclient_context *ssl_client, uint8_t sha256[32])
{
if (!ssl_client) {
log_d("Invalid ssl_client pointer");
return false;
};

const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx);
if (!crt) {
log_d("Failed to get peer cert.");
return false;
};

mbedtls_sha256_context sha256_ctx;
mbedtls_sha256_init(&sha256_ctx);
mbedtls_sha256_starts(&sha256_ctx, false);
mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len);
mbedtls_sha256_finish(&sha256_ctx, sha256);

return true;
}

// Checks if peer certificate has specified domain in CN or SANs
bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name)
{
Expand Down Expand Up @@ -513,5 +456,3 @@ bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name)

return false;
}
#endif