Skip to content

Commit bf5a0f2

Browse files
earlephilhowerdevyte
authored andcommitted
Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference (#4305)
* Fix leak on multiple SSL server connections Fixes #4302 The refcnt setup for the WiFiClientSecure's SSLContext and ClientContext had issues in certain conditions, causing a massive memory leak on each SSL server connection. Depending on the state of the machine, after two or three connections it would OOM and crash. This patch replaces most of the refcnt operations with C++11 shared_ptr operations, cleaning up the code substantially and removing the leakage. Also fixes a race condition where ClientContext was free'd before the SSLContext was stopped/shutdown. When the SSLContext tried to do ssl_free, axtls would attempt to send out the real SSL disconnect bits over the wire, however by this time the ClientContext is invalid and it would fault. * Separate client and server SSL_CTX, support both Refactor to use a separate client SSL_CTX and server SSL_CTX. This allows for separate certificates to be installed on each, and means that you can now have both a *single* client and a *single* server running in parallel at the same time, as they'll have separate memory areas. Tested using mqtt_esp8266 SSL client with a client certificate and a WebServerSecure with its own custom certificate and key in parallel. * Add brackets around a couple if-else clauses
1 parent cda72a0 commit bf5a0f2

File tree

2 files changed

+97
-90
lines changed

2 files changed

+97
-90
lines changed

libraries/ESP8266WiFi/src/WiFiClientSecure.cpp

+96-87
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
7474
class SSLContext
7575
{
7676
public:
77-
SSLContext()
77+
SSLContext(bool isServer = false)
7878
{
79-
if (_ssl_ctx_refcnt == 0) {
80-
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
79+
_isServer = isServer;
80+
if (!_isServer) {
81+
if (_ssl_client_ctx_refcnt == 0) {
82+
_ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
83+
}
84+
++_ssl_client_ctx_refcnt;
85+
} else {
86+
if (_ssl_svr_ctx_refcnt == 0) {
87+
_ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
88+
}
89+
++_ssl_svr_ctx_refcnt;
8190
}
82-
++_ssl_ctx_refcnt;
8391
}
8492

8593
~SSLContext()
8694
{
87-
if (_ssl) {
88-
ssl_free(_ssl);
89-
_ssl = nullptr;
95+
if (io_ctx) {
96+
io_ctx->unref();
97+
io_ctx = nullptr;
9098
}
91-
92-
--_ssl_ctx_refcnt;
93-
if (_ssl_ctx_refcnt == 0) {
94-
ssl_ctx_free(_ssl_ctx);
99+
_ssl = nullptr;
100+
if (!_isServer) {
101+
--_ssl_client_ctx_refcnt;
102+
if (_ssl_client_ctx_refcnt == 0) {
103+
ssl_ctx_free(_ssl_client_ctx);
104+
_ssl_client_ctx = nullptr;
105+
}
106+
} else {
107+
--_ssl_svr_ctx_refcnt;
108+
if (_ssl_svr_ctx_refcnt == 0) {
109+
ssl_ctx_free(_ssl_svr_ctx);
110+
_ssl_svr_ctx = nullptr;
111+
}
95112
}
96113
}
97114

98-
void ref()
99-
{
100-
++_refcnt;
101-
}
102-
103-
void unref()
115+
static void _delete_shared_SSL(SSL *_to_del)
104116
{
105-
if (--_refcnt == 0) {
106-
delete this;
107-
}
117+
ssl_free(_to_del);
108118
}
109119

110120
void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
@@ -116,50 +126,67 @@ class SSLContext
116126
ssl_free will want to send a close notify alert, but the old TCP connection
117127
is already gone at this point, so reset io_ctx. */
118128
io_ctx = nullptr;
119-
ssl_free(_ssl);
129+
_ssl = nullptr;
120130
_available = 0;
121131
_read_ptr = nullptr;
122132
}
123133
io_ctx = ctx;
124-
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
134+
ctx->ref();
135+
136+
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
137+
SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
138+
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
139+
_ssl = _new_ssl_shared;
140+
125141
uint32_t t = millis();
126142

127-
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
143+
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
128144
uint8_t* data;
129-
int rc = ssl_read(_ssl, &data);
145+
int rc = ssl_read(_ssl.get(), &data);
130146
if (rc < SSL_OK) {
131147
ssl_display_error(rc);
132148
break;
133149
}
134150
}
135151
}
136152

137-
void connectServer(ClientContext *ctx) {
153+
void connectServer(ClientContext *ctx, uint32_t timeout_ms)
154+
{
138155
io_ctx = ctx;
139-
_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast<int>(this));
140-
_isServer = true;
156+
ctx->ref();
157+
158+
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
159+
SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast<int>(this));
160+
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
161+
_ssl = _new_ssl_shared;
141162

142-
uint32_t timeout_ms = 5000;
143163
uint32_t t = millis();
144164

145-
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
165+
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
146166
uint8_t* data;
147-
int rc = ssl_read(_ssl, &data);
167+
int rc = ssl_read(_ssl.get(), &data);
148168
if (rc < SSL_OK) {
169+
ssl_display_error(rc);
149170
break;
150171
}
151172
}
152173
}
153174

154175
void stop()
155176
{
177+
if (io_ctx) {
178+
io_ctx->unref();
179+
}
156180
io_ctx = nullptr;
157181
}
158182

159183
bool connected()
160184
{
161-
if (_isServer) return _ssl != nullptr;
162-
else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
185+
if (_isServer) {
186+
return _ssl != nullptr;
187+
} else {
188+
return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK;
189+
}
163190
}
164191

165192
int read(uint8_t* dst, size_t size)
@@ -289,10 +316,9 @@ class SSLContext
289316
return loadObject(type, buf.get(), size);
290317
}
291318

292-
293319
bool loadObject(int type, const uint8_t* data, size_t size)
294320
{
295-
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
321+
int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast<int>(size), nullptr);
296322
if (rc != SSL_OK) {
297323
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
298324
return false;
@@ -302,7 +328,7 @@ class SSLContext
302328

303329
bool verifyCert()
304330
{
305-
int rc = ssl_verify_cert(_ssl);
331+
int rc = ssl_verify_cert(_ssl.get());
306332
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) {
307333
DEBUGV("Allowing self-signed certificate\n");
308334
return true;
@@ -321,12 +347,16 @@ class SSLContext
321347

322348
operator SSL*()
323349
{
324-
return _ssl;
350+
return _ssl.get();
325351
}
326352

327353
static ClientContext* getIOContext(int fd)
328354
{
329-
return reinterpret_cast<SSLContext*>(fd)->io_ctx;
355+
if (fd) {
356+
SSLContext *thisSSL = reinterpret_cast<SSLContext*>(fd);
357+
return thisSSL->io_ctx;
358+
}
359+
return nullptr;
330360
}
331361

332362
protected:
@@ -339,10 +369,9 @@ class SSLContext
339369
optimistic_yield(100);
340370

341371
uint8_t* data;
342-
int rc = ssl_read(_ssl, &data);
372+
int rc = ssl_read(_ssl.get(), &data);
343373
if (rc <= 0) {
344374
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
345-
ssl_free(_ssl);
346375
_ssl = nullptr;
347376
}
348377
return 0;
@@ -359,7 +388,7 @@ class SSLContext
359388
return 0;
360389
}
361390

362-
int rc = ssl_write(_ssl, src, size);
391+
int rc = ssl_write(_ssl.get(), src, size);
363392
if (rc >= 0) {
364393
return rc;
365394
}
@@ -404,19 +433,22 @@ class SSLContext
404433
}
405434

406435
bool _isServer = false;
407-
static SSL_CTX* _ssl_ctx;
408-
static int _ssl_ctx_refcnt;
409-
SSL* _ssl = nullptr;
410-
int _refcnt = 0;
436+
static SSL_CTX* _ssl_client_ctx;
437+
static int _ssl_client_ctx_refcnt;
438+
static SSL_CTX* _ssl_svr_ctx;
439+
static int _ssl_svr_ctx_refcnt;
440+
std::shared_ptr<SSL> _ssl = nullptr;
411441
const uint8_t* _read_ptr = nullptr;
412442
size_t _available = 0;
413443
BufferList _writeBuffers;
414444
bool _allowSelfSignedCerts = false;
415445
ClientContext* io_ctx = nullptr;
416446
};
417447

418-
SSL_CTX* SSLContext::_ssl_ctx = nullptr;
419-
int SSLContext::_ssl_ctx_refcnt = 0;
448+
SSL_CTX* SSLContext::_ssl_client_ctx = nullptr;
449+
int SSLContext::_ssl_client_ctx_refcnt = 0;
450+
SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr;
451+
int SSLContext::_ssl_svr_ctx_refcnt = 0;
420452

421453
WiFiClientSecure::WiFiClientSecure()
422454
{
@@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()
426458

427459
WiFiClientSecure::~WiFiClientSecure()
428460
{
429-
if (_ssl) {
430-
_ssl->unref();
431-
}
432-
}
433-
434-
WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
435-
: WiFiClient(static_cast<const WiFiClient&>(other))
436-
{
437-
_ssl = other._ssl;
438-
if (_ssl) {
439-
_ssl->ref();
440-
}
441-
}
442-
443-
WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
444-
{
445-
(WiFiClient&) *this = rhs;
446-
_ssl = rhs._ssl;
447-
if (_ssl) {
448-
_ssl->ref();
449-
}
450-
return *this;
461+
_ssl = nullptr;
451462
}
452463

453464
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
454-
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
465+
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM,
466+
const uint8_t *rsakey, int rsakeyLen,
467+
const uint8_t *cert, int certLen)
455468
{
469+
// TLS handshake may take more than the 5 second default timeout
470+
_timeout = 15000;
471+
472+
// We've been given the client context from the available() call
456473
_client = client;
457-
if (_ssl) {
458-
_ssl->unref();
459-
_ssl = nullptr;
460-
}
474+
_client->ref();
461475

462-
_ssl = new SSLContext;
463-
_ssl->ref();
476+
// Make the "_ssl" SSLContext, in the constructor there should be none yet
477+
SSLContext *_new_ssl = new SSLContext(true);
478+
std::shared_ptr<SSLContext> _new_ssl_shared(_new_ssl);
479+
_ssl = _new_ssl_shared;
464480

465481
if (usePMEM) {
466482
if (rsakey && rsakeyLen) {
@@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
477493
_ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen);
478494
}
479495
}
480-
_client->ref();
481-
_ssl->connectServer(client);
496+
_ssl->connectServer(client, _timeout);
482497
}
483498

484499
int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
@@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
510525
int WiFiClientSecure::_connectSSL(const char* hostName)
511526
{
512527
if (!_ssl) {
513-
_ssl = new SSLContext;
514-
_ssl->ref();
528+
_ssl = std::make_shared<SSLContext>();
515529
}
516530
_ssl->connect(_client, hostName, _timeout);
517531

518532
auto status = ssl_handshake_status(*_ssl);
519533
if (status != SSL_OK) {
520-
_ssl->unref();
521534
_ssl = nullptr;
522535
return 0;
523536
}
@@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
537550
}
538551

539552
if (rc != SSL_CLOSE_NOTIFY) {
540-
_ssl->unref();
541553
_ssl = nullptr;
542554
}
543555

@@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
640652
{
641653
if (_ssl) {
642654
_ssl->stop();
643-
_ssl->unref();
644-
_ssl = nullptr;
645655
}
646656
WiFiClient::stop();
647657
}
@@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
723733
String domain_name_str(domain_name);
724734
domain_name_str.toLowerCase();
725735

726-
const char* san = NULL;
736+
const char* san = nullptr;
727737
int i = 0;
728-
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
738+
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) {
729739
String san_str(san);
730740
san_str.toLowerCase();
731741
if (matchName(san_str, domain_name_str)) {
@@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
759769
void WiFiClientSecure::_initSSLContext()
760770
{
761771
if (!_ssl) {
762-
_ssl = new SSLContext;
763-
_ssl->ref();
772+
_ssl = std::make_shared<SSLContext>();
764773
}
765774
}
766775

libraries/ESP8266WiFi/src/WiFiClientSecure.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient {
3232
public:
3333
WiFiClientSecure();
3434
~WiFiClientSecure() override;
35-
WiFiClientSecure(const WiFiClientSecure&);
36-
WiFiClientSecure& operator=(const WiFiClientSecure&);
3735

3836
int connect(IPAddress ip, uint16_t port) override;
3937
int connect(const String host, uint16_t port) override;
@@ -91,7 +89,7 @@ friend class WiFiServerSecure; // Needs access to custom constructor below
9189
int _connectSSL(const char* hostName);
9290
bool _verifyDN(const char* name);
9391

94-
SSLContext* _ssl = nullptr;
92+
std::shared_ptr<SSLContext> _ssl = nullptr;
9593
};
9694

9795
#endif //wificlientsecure_h

0 commit comments

Comments
 (0)