diff --git a/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp b/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp index 1ef03f29dff..f0857c32bac 100644 --- a/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp +++ b/libraries/NetworkClientSecure/src/NetworkClientSecure.cpp @@ -317,9 +317,11 @@ void NetworkClientSecure::setCACert(const char *rootCA) { void NetworkClientSecure::setCACertBundle(const uint8_t *bundle) { if (bundle != NULL) { esp_crt_bundle_set(bundle, sizeof(bundle)); + attach_ssl_certificate_bundle(sslclient.get(), true); _use_ca_bundle = true; } else { esp_crt_bundle_detach(NULL); + attach_ssl_certificate_bundle(sslclient.get(), false); _use_ca_bundle = false; } } diff --git a/libraries/NetworkClientSecure/src/ssl_client.cpp b/libraries/NetworkClientSecure/src/ssl_client.cpp index 41e79ee3803..c8d5bbd21ea 100644 --- a/libraries/NetworkClientSecure/src/ssl_client.cpp +++ b/libraries/NetworkClientSecure/src/ssl_client.cpp @@ -51,6 +51,14 @@ void ssl_init(sslclient_context *ssl_client) { ssl_client->peek_buf = -1; } +void attach_ssl_certificate_bundle(sslclient_context *ssl_client, bool att) { + if (att) { + ssl_client->bundle_attach_cb = &esp_crt_bundle_attach; + } else { + ssl_client->bundle_attach_cb = NULL; + } +} + int start_ssl_client( sslclient_context *ssl_client, const IPAddress &ip, uint32_t port, const char *hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos @@ -195,11 +203,14 @@ int start_ssl_client( return handle_error(ret); } } else if (useRootCABundle) { - log_v("Attaching root CA cert bundle"); - ret = esp_crt_bundle_attach(&ssl_client->ssl_conf); - - if (ret < 0) { - return handle_error(ret); + if (ssl_client->bundle_attach_cb != NULL) { + log_v("Attaching root CA cert bundle"); + ret = ssl_client->bundle_attach_cb(&ssl_client->ssl_conf); + if (ret < 0) { + return handle_error(ret); + } + } else { + log_e("useRootCABundle is set, but attach_ssl_certificate_bundle(ssl, true); was not called!"); } } else if (pskIdent != NULL && psKey != NULL) { log_v("Setting up PSK"); diff --git a/libraries/NetworkClientSecure/src/ssl_client.h b/libraries/NetworkClientSecure/src/ssl_client.h index 3e07bf6bc2c..892adc86a95 100644 --- a/libraries/NetworkClientSecure/src/ssl_client.h +++ b/libraries/NetworkClientSecure/src/ssl_client.h @@ -12,6 +12,8 @@ #include "mbedtls/ctr_drbg.h" #include "mbedtls/error.h" +typedef esp_err_t (*crt_bundle_attach_cb)(void *conf); + typedef struct sslclient_context { int socket; mbedtls_ssl_context ssl_ctx; @@ -24,6 +26,8 @@ typedef struct sslclient_context { mbedtls_x509_crt client_cert; mbedtls_pk_context client_key; + crt_bundle_attach_cb bundle_attach_cb; + unsigned long socket_timeout; unsigned long handshake_timeout; @@ -37,6 +41,7 @@ int start_ssl_client( sslclient_context *ssl_client, const IPAddress &ip, uint32_t port, const char *hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos ); +void attach_ssl_certificate_bundle(sslclient_context *ssl_client, bool att); int ssl_starttls_handshake(sslclient_context *ssl_client); void stop_ssl_socket(sslclient_context *ssl_client); int data_to_read(sslclient_context *ssl_client);