@@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
74
74
class SSLContext
75
75
{
76
76
public:
77
- SSLContext ()
77
+ SSLContext (bool isServer = false )
78
78
{
79
- if (_ssl_ctx_refcnt == 0 ) {
80
- _ssl_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
79
+ _isServer = isServer;
80
+ if (!_isServer) {
81
+ if (_ssl_client_ctx_refcnt == 0 ) {
82
+ _ssl_client_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
83
+ }
84
+ ++_ssl_client_ctx_refcnt;
85
+ } else {
86
+ if (_ssl_svr_ctx_refcnt == 0 ) {
87
+ _ssl_svr_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
88
+ }
89
+ ++_ssl_svr_ctx_refcnt;
81
90
}
82
- ++_ssl_ctx_refcnt;
83
91
}
84
92
85
93
~SSLContext ()
86
94
{
87
- if (_ssl ) {
88
- ssl_free (_ssl );
89
- _ssl = nullptr ;
95
+ if (io_ctx ) {
96
+ io_ctx-> unref ( );
97
+ io_ctx = nullptr ;
90
98
}
91
-
92
- --_ssl_ctx_refcnt;
93
- if (_ssl_ctx_refcnt == 0 ) {
94
- ssl_ctx_free (_ssl_ctx);
99
+ _ssl = nullptr ;
100
+ if (!_isServer) {
101
+ --_ssl_client_ctx_refcnt;
102
+ if (_ssl_client_ctx_refcnt == 0 ) {
103
+ ssl_ctx_free (_ssl_client_ctx);
104
+ _ssl_client_ctx = nullptr ;
105
+ }
106
+ } else {
107
+ --_ssl_svr_ctx_refcnt;
108
+ if (_ssl_svr_ctx_refcnt == 0 ) {
109
+ ssl_ctx_free (_ssl_svr_ctx);
110
+ _ssl_svr_ctx = nullptr ;
111
+ }
95
112
}
96
113
}
97
114
98
- void ref ()
99
- {
100
- ++_refcnt;
101
- }
102
-
103
- void unref ()
115
+ static void _delete_shared_SSL (SSL *_to_del)
104
116
{
105
- if (--_refcnt == 0 ) {
106
- delete this ;
107
- }
117
+ ssl_free (_to_del);
108
118
}
109
119
110
120
void connect (ClientContext* ctx, const char * hostName, uint32_t timeout_ms)
@@ -116,50 +126,67 @@ class SSLContext
116
126
ssl_free will want to send a close notify alert, but the old TCP connection
117
127
is already gone at this point, so reset io_ctx. */
118
128
io_ctx = nullptr ;
119
- ssl_free ( _ssl) ;
129
+ _ssl = nullptr ;
120
130
_available = 0 ;
121
131
_read_ptr = nullptr ;
122
132
}
123
133
io_ctx = ctx;
124
- _ssl = ssl_client_new (_ssl_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
134
+ ctx->ref ();
135
+
136
+ // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
137
+ SSL *_new_ssl = ssl_client_new (_ssl_client_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
138
+ std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
139
+ _ssl = _new_ssl_shared;
140
+
125
141
uint32_t t = millis ();
126
142
127
- while (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
143
+ while (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
128
144
uint8_t * data;
129
- int rc = ssl_read (_ssl, &data);
145
+ int rc = ssl_read (_ssl. get () , &data);
130
146
if (rc < SSL_OK) {
131
147
ssl_display_error (rc);
132
148
break ;
133
149
}
134
150
}
135
151
}
136
152
137
- void connectServer (ClientContext *ctx) {
153
+ void connectServer (ClientContext *ctx, uint32_t timeout_ms)
154
+ {
138
155
io_ctx = ctx;
139
- _ssl = ssl_server_new (_ssl_ctx, reinterpret_cast <int >(this ));
140
- _isServer = true ;
156
+ ctx->ref ();
157
+
158
+ // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
159
+ SSL *_new_ssl = ssl_server_new (_ssl_svr_ctx, reinterpret_cast <int >(this ));
160
+ std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
161
+ _ssl = _new_ssl_shared;
141
162
142
- uint32_t timeout_ms = 5000 ;
143
163
uint32_t t = millis ();
144
164
145
- while (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
165
+ while (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
146
166
uint8_t * data;
147
- int rc = ssl_read (_ssl, &data);
167
+ int rc = ssl_read (_ssl. get () , &data);
148
168
if (rc < SSL_OK) {
169
+ ssl_display_error (rc);
149
170
break ;
150
171
}
151
172
}
152
173
}
153
174
154
175
void stop ()
155
176
{
177
+ if (io_ctx) {
178
+ io_ctx->unref ();
179
+ }
156
180
io_ctx = nullptr ;
157
181
}
158
182
159
183
bool connected ()
160
184
{
161
- if (_isServer) return _ssl != nullptr ;
162
- else return _ssl != nullptr && ssl_handshake_status (_ssl) == SSL_OK;
185
+ if (_isServer) {
186
+ return _ssl != nullptr ;
187
+ } else {
188
+ return _ssl != nullptr && ssl_handshake_status (_ssl.get ()) == SSL_OK;
189
+ }
163
190
}
164
191
165
192
int read (uint8_t * dst, size_t size)
@@ -289,10 +316,9 @@ class SSLContext
289
316
return loadObject (type, buf.get (), size);
290
317
}
291
318
292
-
293
319
bool loadObject (int type, const uint8_t * data, size_t size)
294
320
{
295
- int rc = ssl_obj_memory_load (_ssl_ctx , type, data, static_cast <int >(size), nullptr );
321
+ int rc = ssl_obj_memory_load (_isServer?_ssl_svr_ctx:_ssl_client_ctx , type, data, static_cast <int >(size), nullptr );
296
322
if (rc != SSL_OK) {
297
323
DEBUGV (" loadObject: ssl_obj_memory_load returned %d\n " , rc);
298
324
return false ;
@@ -302,7 +328,7 @@ class SSLContext
302
328
303
329
bool verifyCert ()
304
330
{
305
- int rc = ssl_verify_cert (_ssl);
331
+ int rc = ssl_verify_cert (_ssl. get () );
306
332
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR (X509_VFY_ERROR_SELF_SIGNED)) {
307
333
DEBUGV (" Allowing self-signed certificate\n " );
308
334
return true ;
@@ -321,12 +347,16 @@ class SSLContext
321
347
322
348
operator SSL*()
323
349
{
324
- return _ssl;
350
+ return _ssl. get () ;
325
351
}
326
352
327
353
static ClientContext* getIOContext (int fd)
328
354
{
329
- return reinterpret_cast <SSLContext*>(fd)->io_ctx ;
355
+ if (fd) {
356
+ SSLContext *thisSSL = reinterpret_cast <SSLContext*>(fd);
357
+ return thisSSL->io_ctx ;
358
+ }
359
+ return nullptr ;
330
360
}
331
361
332
362
protected:
@@ -339,10 +369,9 @@ class SSLContext
339
369
optimistic_yield (100 );
340
370
341
371
uint8_t * data;
342
- int rc = ssl_read (_ssl, &data);
372
+ int rc = ssl_read (_ssl. get () , &data);
343
373
if (rc <= 0 ) {
344
374
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
345
- ssl_free (_ssl);
346
375
_ssl = nullptr ;
347
376
}
348
377
return 0 ;
@@ -359,7 +388,7 @@ class SSLContext
359
388
return 0 ;
360
389
}
361
390
362
- int rc = ssl_write (_ssl, src, size);
391
+ int rc = ssl_write (_ssl. get () , src, size);
363
392
if (rc >= 0 ) {
364
393
return rc;
365
394
}
@@ -404,19 +433,22 @@ class SSLContext
404
433
}
405
434
406
435
bool _isServer = false ;
407
- static SSL_CTX* _ssl_ctx;
408
- static int _ssl_ctx_refcnt;
409
- SSL* _ssl = nullptr ;
410
- int _refcnt = 0 ;
436
+ static SSL_CTX* _ssl_client_ctx;
437
+ static int _ssl_client_ctx_refcnt;
438
+ static SSL_CTX* _ssl_svr_ctx;
439
+ static int _ssl_svr_ctx_refcnt;
440
+ std::shared_ptr<SSL> _ssl = nullptr ;
411
441
const uint8_t * _read_ptr = nullptr ;
412
442
size_t _available = 0 ;
413
443
BufferList _writeBuffers;
414
444
bool _allowSelfSignedCerts = false ;
415
445
ClientContext* io_ctx = nullptr ;
416
446
};
417
447
418
- SSL_CTX* SSLContext::_ssl_ctx = nullptr ;
419
- int SSLContext::_ssl_ctx_refcnt = 0 ;
448
+ SSL_CTX* SSLContext::_ssl_client_ctx = nullptr ;
449
+ int SSLContext::_ssl_client_ctx_refcnt = 0 ;
450
+ SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr ;
451
+ int SSLContext::_ssl_svr_ctx_refcnt = 0 ;
420
452
421
453
WiFiClientSecure::WiFiClientSecure ()
422
454
{
@@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()
426
458
427
459
WiFiClientSecure::~WiFiClientSecure ()
428
460
{
429
- if (_ssl) {
430
- _ssl->unref ();
431
- }
432
- }
433
-
434
- WiFiClientSecure::WiFiClientSecure (const WiFiClientSecure& other)
435
- : WiFiClient(static_cast <const WiFiClient&>(other))
436
- {
437
- _ssl = other._ssl ;
438
- if (_ssl) {
439
- _ssl->ref ();
440
- }
441
- }
442
-
443
- WiFiClientSecure& WiFiClientSecure::operator =(const WiFiClientSecure& rhs)
444
- {
445
- (WiFiClient&) *this = rhs;
446
- _ssl = rhs._ssl ;
447
- if (_ssl) {
448
- _ssl->ref ();
449
- }
450
- return *this ;
461
+ _ssl = nullptr ;
451
462
}
452
463
453
464
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
454
- WiFiClientSecure::WiFiClientSecure (ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
465
+ WiFiClientSecure::WiFiClientSecure (ClientContext* client, bool usePMEM,
466
+ const uint8_t *rsakey, int rsakeyLen,
467
+ const uint8_t *cert, int certLen)
455
468
{
469
+ // TLS handshake may take more than the 5 second default timeout
470
+ _timeout = 15000 ;
471
+
472
+ // We've been given the client context from the available() call
456
473
_client = client;
457
- if (_ssl) {
458
- _ssl->unref ();
459
- _ssl = nullptr ;
460
- }
474
+ _client->ref ();
461
475
462
- _ssl = new SSLContext;
463
- _ssl->ref ();
476
+ // Make the "_ssl" SSLContext, in the constructor there should be none yet
477
+ SSLContext *_new_ssl = new SSLContext (true );
478
+ std::shared_ptr<SSLContext> _new_ssl_shared (_new_ssl);
479
+ _ssl = _new_ssl_shared;
464
480
465
481
if (usePMEM) {
466
482
if (rsakey && rsakeyLen) {
@@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
477
493
_ssl->loadObject (SSL_OBJ_X509_CERT, cert, certLen);
478
494
}
479
495
}
480
- _client->ref ();
481
- _ssl->connectServer (client);
496
+ _ssl->connectServer (client, _timeout);
482
497
}
483
498
484
499
int WiFiClientSecure::connect (IPAddress ip, uint16_t port)
@@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
510
525
int WiFiClientSecure::_connectSSL (const char * hostName)
511
526
{
512
527
if (!_ssl) {
513
- _ssl = new SSLContext;
514
- _ssl->ref ();
528
+ _ssl = std::make_shared<SSLContext>();
515
529
}
516
530
_ssl->connect (_client, hostName, _timeout);
517
531
518
532
auto status = ssl_handshake_status (*_ssl);
519
533
if (status != SSL_OK) {
520
- _ssl->unref ();
521
534
_ssl = nullptr ;
522
535
return 0 ;
523
536
}
@@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
537
550
}
538
551
539
552
if (rc != SSL_CLOSE_NOTIFY) {
540
- _ssl->unref ();
541
553
_ssl = nullptr ;
542
554
}
543
555
@@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
640
652
{
641
653
if (_ssl) {
642
654
_ssl->stop ();
643
- _ssl->unref ();
644
- _ssl = nullptr ;
645
655
}
646
656
WiFiClient::stop ();
647
657
}
@@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
723
733
String domain_name_str (domain_name);
724
734
domain_name_str.toLowerCase ();
725
735
726
- const char * san = NULL ;
736
+ const char * san = nullptr ;
727
737
int i = 0 ;
728
- while ((san = ssl_get_cert_subject_alt_dnsname (*_ssl, i)) != NULL ) {
738
+ while ((san = ssl_get_cert_subject_alt_dnsname (*_ssl, i)) != nullptr ) {
729
739
String san_str (san);
730
740
san_str.toLowerCase ();
731
741
if (matchName (san_str, domain_name_str)) {
@@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
759
769
void WiFiClientSecure::_initSSLContext ()
760
770
{
761
771
if (!_ssl) {
762
- _ssl = new SSLContext;
763
- _ssl->ref ();
772
+ _ssl = std::make_shared<SSLContext>();
764
773
}
765
774
}
766
775
0 commit comments