diff --git a/libraries/WiFiClientSecure/src/ssl_client.cpp b/libraries/WiFiClientSecure/src/ssl_client.cpp index c910206b3c9..438817f7cbf 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.cpp +++ b/libraries/WiFiClientSecure/src/ssl_client.cpp @@ -19,13 +19,10 @@ #include "ssl_client.h" #include "WiFi.h" -#ifndef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED -# warning "Please configure IDF framework to include mbedTLS -> Enable pre-shared-key ciphersuites and activate at least one cipher" -#else const char *pers = "esp32-tls"; -static int _handle_error(int err, const char * function, int line) +static int _handle_error(int err, const char * file, int line) { if(err == -30848){ return err; @@ -33,9 +30,9 @@ static int _handle_error(int err, const char * function, int line) #ifdef MBEDTLS_ERROR_C char error_buf[100]; mbedtls_strerror(err, error_buf, 100); - log_e("[%s():%d]: (%d) %s", function, line, err, error_buf); + log_e("[%s():%d]: (%d) %s", file, line, err, error_buf); #else - log_e("[%s():%d]: code %d", function, line, err); + log_e("[%s():%d]: code %d", file, line, err); #endif return err; } @@ -45,23 +42,21 @@ static int _handle_error(int err, const char * function, int line) void ssl_init(sslclient_context *ssl_client) { + // reset embedded pointers to zero + memset(ssl_client, 0, sizeof(sslclient_context)); mbedtls_ssl_init(&ssl_client->ssl_ctx); mbedtls_ssl_config_init(&ssl_client->ssl_conf); mbedtls_ctr_drbg_init(&ssl_client->drbg_ctx); } -int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos) +int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) { char buf[512]; int ret, flags; int enable = 1; log_v("Free internal heap before TLS %u", ESP.getFreeHeap()); - if (rootCABuff == NULL && pskIdent == NULL && psKey == NULL && !insecure) { - return -1; - } - log_v("Starting socket"); ssl_client->socket = -1; @@ -76,67 +71,26 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p return -1; } - fcntl( ssl_client->socket, F_SETFL, fcntl( ssl_client->socket, F_GETFL, 0 ) | O_NONBLOCK ); struct sockaddr_in serv_addr; memset(&serv_addr, 0, sizeof(serv_addr)); serv_addr.sin_family = AF_INET; serv_addr.sin_addr.s_addr = srv; serv_addr.sin_port = htons(port); - if(timeout <= 0){ - timeout = 30000; // Milli seconds. - } - - fd_set fdset; - struct timeval tv; - FD_ZERO(&fdset); - FD_SET(ssl_client->socket, &fdset); - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; - - int res = lwip_connect(ssl_client->socket, (struct sockaddr*)&serv_addr, sizeof(serv_addr)); - if (res < 0 && errno != EINPROGRESS) { - log_e("connect on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno)); - close(ssl_client->socket); - return -1; - } - - res = select(ssl_client->socket + 1, nullptr, &fdset, nullptr, timeout<0 ? nullptr : &tv); - if (res < 0) { - log_e("select on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno)); - close(ssl_client->socket); - return -1; - } else if (res == 0) { - log_i("select returned due to timeout %d ms for fd %d", timeout, ssl_client->socket); - close(ssl_client->socket); - return -1; - } else { - int sockerr; - socklen_t len = (socklen_t)sizeof(int); - res = getsockopt(ssl_client->socket, SOL_SOCKET, SO_ERROR, &sockerr, &len); - - if (res < 0) { - log_e("getsockopt on fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno)); - close(ssl_client->socket); - return -1; - } - - if (sockerr != 0) { - log_e("socket error on fd %d, errno: %d, \"%s\"", ssl_client->socket, sockerr, strerror(sockerr)); - close(ssl_client->socket); - return -1; + if (lwip_connect(ssl_client->socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) { + if(timeout <= 0){ + timeout = 30000; } + lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); + lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)); + lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable)); + lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)); + } else { + log_e("Connect to Server failed!"); + return -1; } - -#define ROE(x,msg) { if (((x)<0)) { log_e("LWIP Socket config of " msg " failed."); return -1; }} - ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),"SO_RCVTIMEO"); - ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),"SO_SNDTIMEO"); - - ROE(lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable)),"TCP_NODELAY"); - ROE(lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)),"SO_KEEPALIVE"); - - + fcntl( ssl_client->socket, F_SETFL, fcntl( ssl_client->socket, F_GETFL, 0 ) | O_NONBLOCK ); log_v("Seeding the random number generator"); mbedtls_entropy_init(&ssl_client->entropy_ctx); @@ -156,20 +110,10 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p return handle_error(ret); } - if (alpn_protos != NULL) { - log_v("Setting ALPN protocols"); - if ((ret = mbedtls_ssl_conf_alpn_protocols(&ssl_client->ssl_conf, alpn_protos) ) != 0) { - return handle_error(ret); - } - } - // MBEDTLS_SSL_VERIFY_REQUIRED if a CA certificate is defined on Arduino IDE and // MBEDTLS_SSL_VERIFY_NONE if not. - if (insecure) { - mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE); - log_i("WARNING: Skipping SSL Verification. INSECURE!"); - } else if (rootCABuff != NULL) { + if (rootCABuff != NULL) { log_v("Loading CA cert"); mbedtls_x509_crt_init(&ssl_client->ca_cert); mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); @@ -177,8 +121,6 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); //mbedtls_ssl_conf_verify(&ssl_client->ssl_ctx, my_verify, NULL ); if (ret < 0) { - // free the ca_cert in the case parse failed, otherwise, the old ca_cert still in the heap memory, that lead to "out of memory" crash. - mbedtls_x509_crt_free(&ssl_client->ca_cert); return handle_error(ret); } } else if (pskIdent != NULL && psKey != NULL) { @@ -212,10 +154,11 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p return handle_error(ret); } } else { - return -1; + mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE); + log_i("WARNING: Use certificates for a more secure communication!"); } - if (!insecure && cli_cert != NULL && cli_key != NULL) { + if (cli_cert != NULL && cli_key != NULL) { mbedtls_x509_crt_init(&ssl_client->client_cert); mbedtls_pk_init(&ssl_client->client_key); @@ -223,8 +166,6 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); if (ret < 0) { - // free the client_cert in the case parse failed, otherwise, the old client_cert still in the heap memory, that lead to "out of memory" crash. - mbedtls_x509_crt_free(&ssl_client->client_cert); return handle_error(ret); } @@ -232,6 +173,7 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); if (ret != 0) { + mbedtls_x509_crt_free(&ssl_client->client_cert); // cert+key are free'd in pair return handle_error(ret); } @@ -243,7 +185,7 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p // Hostname set here should match CN in server certificate if((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host)) != 0){ return handle_error(ret); - } + } mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx); @@ -257,11 +199,12 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p unsigned long handshake_start_time=millis(); while ((ret = mbedtls_ssl_handshake(&ssl_client->ssl_ctx)) != 0) { if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + log_d("Verify result %04X", mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx)); return handle_error(ret); } if((millis()-handshake_start_time)>ssl_client->handshake_timeout) - return -1; - vTaskDelay(2);//2 ticks + return -1; + vTaskDelay(10 / portTICK_PERIOD_MS); } @@ -277,10 +220,10 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p log_v("Verifying peer X.509 certificate..."); if ((flags = mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx)) != 0) { - memset(buf, 0, sizeof(buf)); + bzero(buf, sizeof(buf)); mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", flags); log_e("Failed to verify peer certificate! verification info: %s", buf); - stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); //It's not safe continue. + //It's not safe continue. return handle_error(ret); } else { log_v("Certificate verified."); @@ -312,11 +255,20 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons close(ssl_client->socket); ssl_client->socket = -1; } - + // avoid memory leak if ssl connection attempt failed + if (ssl_client->ssl_conf.ca_chain != NULL) { + mbedtls_x509_crt_free(&ssl_client->ca_cert); + } + if (ssl_client->ssl_conf.key_cert != NULL) { + mbedtls_x509_crt_free(&ssl_client->client_cert); + mbedtls_pk_free(&ssl_client->client_key); + } mbedtls_ssl_free(&ssl_client->ssl_ctx); mbedtls_ssl_config_free(&ssl_client->ssl_conf); mbedtls_ctr_drbg_free(&ssl_client->drbg_ctx); mbedtls_entropy_free(&ssl_client->entropy_ctx); + // reset embedded pointers to zero + memset(ssl_client, 0, sizeof(sslclient_context)); } @@ -334,26 +286,27 @@ int data_to_read(sslclient_context *ssl_client) return res; } -int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) + +int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len) { - log_v("Writing HTTP request with %d bytes...", len); //for low level debug + log_v("Writing SSL data..."); //for low level debug int ret = -1; while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) { - if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) { - log_v("Handling error %d", ret); //for low level debug + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { return handle_error(ret); } - //wait for space to become available - vTaskDelay(2); } + len = ret; + //log_v("%d bytes written", len); //for low level debug return ret; } + int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length) { - //log_d( "Reading HTTP response..."); //for low level debug + log_v( "Reading SSL data..."); //for low level debug int ret = -1; ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length); @@ -425,10 +378,22 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const fingerprint_local[i] = low | (high << 4); } + // Get certificate provided by the peer + const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); + + if (!crt) + { + log_d("could not fetch peer certificate"); + return false; + } + // Calculate certificate's SHA256 fingerprint uint8_t fingerprint_remote[32]; - if(!get_peer_fingerprint(ssl_client, fingerprint_remote)) - return false; + mbedtls_sha256_context sha256_ctx; + mbedtls_sha256_init(&sha256_ctx); + mbedtls_sha256_starts(&sha256_ctx, false); + mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len); + mbedtls_sha256_finish(&sha256_ctx, fingerprint_remote); // Check if fingerprints match if (memcmp(fingerprint_local, fingerprint_remote, 32)) @@ -444,28 +409,6 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const return true; } -bool get_peer_fingerprint(sslclient_context *ssl_client, uint8_t sha256[32]) -{ - if (!ssl_client) { - log_d("Invalid ssl_client pointer"); - return false; - }; - - const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); - if (!crt) { - log_d("Failed to get peer cert."); - return false; - }; - - mbedtls_sha256_context sha256_ctx; - mbedtls_sha256_init(&sha256_ctx); - mbedtls_sha256_starts(&sha256_ctx, false); - mbedtls_sha256_update(&sha256_ctx, crt->raw.p, crt->raw.len); - mbedtls_sha256_finish(&sha256_ctx, sha256); - - return true; -} - // Checks if peer certificate has specified domain in CN or SANs bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name) { @@ -513,5 +456,3 @@ bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name) return false; } -#endif -