diff --git a/libraries/WiFiClientSecure/examples/WiFiClientSecureProtocolUpgrade/.skip.esp32h2 b/libraries/WiFiClientSecure/examples/WiFiClientSecureProtocolUpgrade/.skip.esp32h2 new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libraries/WiFiClientSecure/examples/WiFiClientSecureProtocolUpgrade/WiFiClientSecureProtocolUpgrade.ino b/libraries/WiFiClientSecure/examples/WiFiClientSecureProtocolUpgrade/WiFiClientSecureProtocolUpgrade.ino new file mode 100644 index 00000000000..06c19d105c4 --- /dev/null +++ b/libraries/WiFiClientSecure/examples/WiFiClientSecureProtocolUpgrade/WiFiClientSecureProtocolUpgrade.ino @@ -0,0 +1,176 @@ +/* STARTSSL example + + Inline upgrading from a clear-text connection to an SSL/TLS connection. + + Some protocols such as SMTP, XMPP, Mysql, Postgress and others allow, or require, + that you start the connection without encryption; and then send a command to switch + over to encryption. + + E.g. a typical SMTP submission would entail a dialogue such as this: + + 1. client connects to server in the clear + 2. server says hello + 3. client sents a EHLO + 4. server tells the client that it supports SSL/TLS + 5. client sends a 'STARTTLS' to make use of this faciltiy + 6. client/server negiotiate a SSL or TLS connection. + 7. client sends another EHLO + 8. server now tells the client what (else) is supported; such as additional authentication options. + ... conversation continues encrypted. + + This can be enabled in WiFiClientSecure by telling it to start in plaintext: + + client.setPlainStart(); + + and client is than a plain, TCP, connection (just as WiFiClient would be); until the client calls + the method: + + client.startTLS(); // returns zero on error; non zero on success. + + After which things switch to TLS/SSL. +*/ + +#include + +#ifndef WIFI_NETWORK +#define WIFI_NETWORK "YOUR Wifi SSID" +#endif + +#ifndef WIFI_PASSWD +#define WIFI_PASSWD "your-secret-password" +#endif + +#ifndef SMTP_HOST +#define SMTP_HOST "smtp.gmail.com" +#endif + +#ifndef SMTP_PORT +#define SMTP_PORT (587) // Standard (plaintext) submission port +#endif + +const char* ssid = WIFI_NETWORK; // your network SSID (name of wifi network) +const char* password = WIFI_PASSWD; // your network password +const char* server = SMTP_HOST; // Server URL +const int submission_port = SMTP_PORT; // submission port. + +WiFiClientSecure client; + +static bool readAllSMTPLines(); + +void setup() { + int ret; + //Initialize serial and wait for port to open: + Serial.begin(115200); + delay(100); + + Serial.print("Attempting to connect to SSID: "); + Serial.print(ssid); + WiFi.begin(ssid, password); + + // attempt to connect to Wifi network: + while (WiFi.status() != WL_CONNECTED) { + Serial.print("."); + // wait 1 second for re-trying + delay(1000); + } + + Serial.print("Connected to "); + Serial.println(ssid); + + Serial.printf("\nStarting connection to server: %s:%d\n", server, submission_port); + + + // skip verification for this demo. In production one should at the very least + // enable TOFU; or ideally hardcode a (CA) certificate that is trusted. + client.setInsecure(); + + // Enable a plain-test start. + client.setPlainStart(); + + if (!client.connect(server, SMTP_PORT)) { + Serial.println("Connection failed!"); + return; + }; + + Serial.println("Connected to server (in the clear, in plaintest)"); + + if (!readAllSMTPLines()) goto err; + + Serial.println("Sending : EHLO\t\tin the clear"); + client.print("EHLO there\r\n"); + + if (!readAllSMTPLines()) goto err; + + Serial.println("Sending : STARTTLS\t\tin the clear"); + client.print("STARTTLS\r\n"); + + if (!readAllSMTPLines()) goto err; + + Serial.println("Upgrading connection to TLS"); + if ((ret=client.startTLS()) <= 0) { + Serial.printf("Upgrade connection failed: err %d\n", ret); + goto err; + } + + Serial.println("Sending : EHLO again\t\tover the now encrypted connection"); + client.print("EHLO again\r\n"); + + if (!readAllSMTPLines()) goto err; + + // normally, as this point - we'd be authenticating and then be submitting + // an email. This has been left out of this example. + + Serial.println("Sending : QUIT\t\t\tover the now encrypted connection"); + client.print("QUIT\r\n"); + + if (!readAllSMTPLines()) goto err; + + Serial.println("Completed OK\n"); +err: + Serial.println("Closing connection"); + client.stop(); +} + +// SMTP command repsponse start with three digits and a space; +// or, for continuation, with three digits and a '-'. +static bool readAllSMTPLines() { + String s = ""; + int i; + + // blocking read; we cannot rely on a timeout + // of a WiFiClientSecure read; as it is non + // blocking. + const unsigned long timeout = 15 * 1000; + unsigned long start = millis(); // the timeout is for the entire CMD block response; not per character/line. + while (1) { + while ((i = client.available()) == 0 && millis() - start < timeout) { + /* .. wait */ + }; + if (i == 0) { + Serial.println("Timeout reading SMTP response"); + return false; + }; + if (i < 0) + break; + + i = client.read(); + if (i < 0) + break; + + if (i > 31 && i < 128) s += (char)i; + if (i == 0x0A) { + Serial.print("Receiving: "); + Serial.println(s); + if (s.charAt(3) == ' ') + return true; + s = ""; + } + } + Serial.printf("Error reading SMTP command response line: %d\n", i); + return false; +} + +void loop() { + // do nothing +} + diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp index 2f9da58f9ad..ebecf94ffb6 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.cpp @@ -143,9 +143,16 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *CA_ce int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key) { int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos); + + if (ret >=0 && ! _stillinPlainStart) + ret = ssl_starttls_handshake(sslclient); + else + log_i("Actual TLS start posponed."); + _lastError = ret; + if (ret < 0) { - log_e("start_ssl_client: %d", ret); + log_e("start_ssl_client: connect failed: %d", ret); stop(); return 0; } @@ -153,6 +160,23 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *host, con return 1; } +int WiFiClientSecure::startTLS() +{ + int ret = 1; + if (_stillinPlainStart) { + log_i("startTLS: starting TLS/SSL on this dplain connection"); + ret = ssl_starttls_handshake(sslclient); + if (ret < 0) { + log_e("startTLS: %d", ret); + stop(); + return 0; + }; + _stillinPlainStart = false; + } else + log_i("startTLS: ignoring StartTLS - as we should be secure already"); + return 1; +} + int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey) { return connect(ip.toString().c_str(), port, pskIdent, psKey); } @@ -167,7 +191,7 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskId int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos); _lastError = ret; if (ret < 0) { - log_e("start_ssl_client: %d", ret); + log_e("start_ssl_client: connect failed %d", ret); stop(); return 0; } @@ -192,10 +216,7 @@ int WiFiClientSecure::read() { uint8_t data = -1; int res = read(&data, 1); - if (res < 0) { - return res; - } - return data; + return res < 0 ? res: data; } size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) @@ -203,6 +224,10 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) if (!_connected) { return 0; } + + if (_stillinPlainStart) + return send_net_data(sslclient, buf, size); + if(_lastWriteTimeout != _timeout){ struct timeval timeout_tv; timeout_tv.tv_sec = _timeout / 1000; @@ -212,9 +237,9 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) _lastWriteTimeout = _timeout; } } - int res = send_ssl_data(sslclient, buf, size); if (res < 0) { + log_e("Closing connection on failed write"); stop(); res = 0; } @@ -223,6 +248,9 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) int WiFiClientSecure::read(uint8_t *buf, size_t size) { + if(_stillinPlainStart) + return get_net_receive(sslclient, buf, size); + if(_lastReadTimeout != _timeout){ if(fd() >= 0){ struct timeval timeout_tv; @@ -235,7 +263,7 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size) } } - int peeked = 0; + int peeked = 0, res = -1; int avail = available(); if ((!buf && size) || avail <= 0) { return -1; @@ -254,9 +282,10 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size) buf++; peeked = 1; } - - int res = get_ssl_receive(sslclient, buf, size); + res = get_ssl_receive(sslclient, buf, size); + if (res < 0) { + log_e("Closing connection on failed read"); stop(); return peeked?peeked:res; } @@ -265,12 +294,17 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size) int WiFiClientSecure::available() { - int peeked = (_peek >= 0); + if (_stillinPlainStart) + return peek_net_receive(sslclient,0); + + int peeked = (_peek >= 0), res = -1; if (!_connected) { return peeked; } - int res = data_to_read(sslclient); - if (res < 0) { + res = data_to_read(sslclient); + + if (res < 0 && !_stillinPlainStart) { + log_e("Closing connection on failed available check"); stop(); return peeked?peeked:res; } @@ -406,3 +440,4 @@ int WiFiClientSecure::fd() const { return sslclient->socket; } + diff --git a/libraries/WiFiClientSecure/src/WiFiClientSecure.h b/libraries/WiFiClientSecure/src/WiFiClientSecure.h index 8c130f450cc..31daebbdb42 100644 --- a/libraries/WiFiClientSecure/src/WiFiClientSecure.h +++ b/libraries/WiFiClientSecure/src/WiFiClientSecure.h @@ -34,6 +34,7 @@ class WiFiClientSecure : public WiFiClient int _peek = -1; int _timeout; bool _use_insecure; + bool _stillinPlainStart = false; const char *_CA_cert; const char *_cert; const char *_private_key; @@ -78,6 +79,17 @@ class WiFiClientSecure : public WiFiClient bool verify(const char* fingerprint, const char* domain_name); void setHandshakeTimeout(unsigned long handshake_timeout); void setAlpnProtocols(const char **alpn_protos); + + // Certain protocols start in plain-text; and then have the client + // give some STARTSSL command to `upgrade' the connection to TLS + // or SSL. Setting PlainStart to true (the default is false) enables + // this. It is up to the application code to then call 'startTLS()' + // at the right point to initialise the SSL or TLS upgrade. + + void setPlainStart() { _stillinPlainStart = true; }; + bool stillInPlainStart() { return _stillinPlainStart; }; + int startTLS(); + const mbedtls_x509_crt* getPeerCertificate() { return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx); }; bool getFingerprintSHA256(uint8_t sha256_result[32]) { return get_peer_fingerprint(sslclient, sha256_result); }; int fd() const; diff --git a/libraries/WiFiClientSecure/src/ssl_client.cpp b/libraries/WiFiClientSecure/src/ssl_client.cpp index 8dcf05877ba..a4308abb1e8 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.cpp +++ b/libraries/WiFiClientSecure/src/ssl_client.cpp @@ -55,8 +55,7 @@ void ssl_init(sslclient_context *ssl_client) int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos) { - char buf[512]; - int ret, flags; + int ret; int enable = 1; log_v("Free internal heap before TLS %u", ESP.getFreeHeap()); @@ -226,6 +225,9 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_ return -1; } + // Note - this check for BOTH key and cert is relied on + // later during cleanup. + if (!insecure && cli_cert != NULL && cli_key != NULL) { mbedtls_x509_crt_init(&ssl_client->client_cert); mbedtls_pk_init(&ssl_client->client_key); @@ -267,6 +269,13 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_ } mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, &ssl_client->socket, mbedtls_net_send, mbedtls_net_recv, NULL ); + return ssl_client->socket; +} + +int ssl_starttls_handshake(sslclient_context *ssl_client) +{ + char buf[512]; + int ret, flags; log_v("Performing the SSL/TLS handshake..."); unsigned long handshake_start_time=millis(); @@ -280,7 +289,7 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_ } - if (cli_cert != NULL && cli_key != NULL) { + if (ssl_client->client_cert.version) { log_d("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx)); if ((ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx)) >= 0) { log_d("Record expansion is %d", ret); @@ -300,15 +309,16 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_ log_v("Certificate verified."); } - if (rootCABuff != NULL) { + if (ssl_client->ca_cert.version) { mbedtls_x509_crt_free(&ssl_client->ca_cert); } - if (cli_cert != NULL) { + // We know that we always have a client cert/key pair -- and we + // cannot look into the private client_key pk struct for newer + // versions of mbedtls. So rely on a public field of the cert + // and infer that there is a key too. + if (ssl_client->client_cert.version) { mbedtls_x509_crt_free(&ssl_client->client_cert); - } - - if (cli_key != NULL) { mbedtls_pk_free(&ssl_client->client_key); } @@ -317,7 +327,6 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_ return ssl_client->socket; } - void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) { log_v("Cleaning SSL connection."); @@ -328,13 +337,13 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons } // avoid memory leak if ssl connection attempt failed - //if (ssl_client->ssl_conf.ca_chain != NULL) { + // if (ssl_client->ssl_conf.ca_chain != NULL) { mbedtls_x509_crt_free(&ssl_client->ca_cert); - //} - //if (ssl_client->ssl_conf.key_cert != NULL) { + // } + // if (ssl_client->ssl_conf.key_cert != NULL) { mbedtls_x509_crt_free(&ssl_client->client_cert); mbedtls_pk_free(&ssl_client->client_key); - //} + // } mbedtls_ssl_free(&ssl_client->ssl_ctx); mbedtls_ssl_config_free(&ssl_client->ssl_conf); mbedtls_ctr_drbg_free(&ssl_client->drbg_ctx); @@ -368,10 +377,8 @@ int data_to_read(sslclient_context *ssl_client) int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) { - log_v("Writing HTTP request with %d bytes...", len); //for low level debug - int ret = -1; - unsigned long write_start_time=millis(); + int ret = -1; while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) { if((millis()-write_start_time)>ssl_client->socket_timeout) { @@ -391,14 +398,60 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len return ret; } -int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length) +// Some protocols, such as SMTP, XMPP, MySQL/Posgress and various others +// do a 'in-line' upgrade from plaintext to SSL or TLS (usually with some +// sort of 'STARTTLS' textual command from client to sever). For this +// we need to have access to the 'raw' socket; i.e. without TLS/SSL state +// handling before the handshake starts; but after setting up the TLS +// connection. +// +int peek_net_receive(sslclient_context *ssl_client, int timeout) { +#if MBEDTLS_FIXED_LINKING_NET_POLL + int ret = mbedtls_net_poll((mbedtls_net_context*)ssl_client, MBEDTLS_NET_POLL_READ, timeout); + ret == MBEDTLS_NET_POLL_READ ? 1 : ret; +#else + // We should be using mbedtls_net_poll(); which is part of mbedtls and + // included in the EspressifSDK. Unfortunately - it did not make it into + // the statically linked library file. So, for now, we replace it by + // substancially similar code. + // + struct timeval tv = { .tv_sec = timeout / 1000, .tv_usec = (timeout % 1000) * 1000 }; + + fd_set fdset; + FD_SET(ssl_client->socket, &fdset); + + int ret = select(ssl_client->socket + 1, &fdset, nullptr, nullptr, timeout<0 ? nullptr : &tv); + if (ret < 0) { + log_e("select on read fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno)); + lwip_close(ssl_client->socket); + ssl_client->socket = -1; + return -1; + }; +#endif + return ret; +}; + +int get_net_receive(sslclient_context *ssl_client, uint8_t *data, int length) { - //log_d( "Reading HTTP response..."); //for low level debug - int ret = -1; + int ret = peek_net_receive(ssl_client,ssl_client->socket_timeout); + if (ret > 0) + ret = mbedtls_net_recv(ssl_client, data, length); - ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length); + // log_v( "%d bytes NET read of %d", ret, length); //for low level debug + return ret; +} - //log_v( "%d bytes read", ret); //for low level debug +int send_net_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) { + int ret = mbedtls_net_send(ssl_client, data, len); + // log_v("Net sending %d btes->ret %d", len, ret); //for low level debug + return ret; +} + + +int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length) +{ + int ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length); + // log_v( "%d bytes SSL read", ret); //for low level debug return ret; } diff --git a/libraries/WiFiClientSecure/src/ssl_client.h b/libraries/WiFiClientSecure/src/ssl_client.h index 69e49707cba..f42ad534139 100644 --- a/libraries/WiFiClientSecure/src/ssl_client.h +++ b/libraries/WiFiClientSecure/src/ssl_client.h @@ -31,10 +31,14 @@ typedef struct sslclient_context { void ssl_init(sslclient_context *ssl_client); int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos); +int ssl_starttls_handshake(sslclient_context *ssl_client); void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); int data_to_read(sslclient_context *ssl_client); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len); int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length); +int send_net_data(sslclient_context *ssl_client, const uint8_t *data, size_t len); +int get_net_receive(sslclient_context *ssl_client, uint8_t *data, int length); +int peek_net_receive(sslclient_context *ssl_client, int timeout); bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name); bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name); bool get_peer_fingerprint(sslclient_context *ssl_client, uint8_t sha256[32]);