@@ -53,7 +53,7 @@ static int verify_digest(SSL *ssl, int mode, const uint8_t *buf, int read_len);
53
53
static void * crypt_new (SSL * ssl , uint8_t * key , uint8_t * iv , int is_decrypt , void * cached );
54
54
static int send_raw_packet (SSL * ssl , uint8_t protocol );
55
55
static void certificate_free (SSL * ssl );
56
- static int increase_bm_data_size (SSL * ssl );
56
+ static int increase_bm_data_size (SSL * ssl , size_t size );
57
57
58
58
/**
59
59
* The server will pick the cipher based on the order that the order that the
@@ -285,6 +285,11 @@ EXP_FUNC int STDCALL ssl_write(SSL *ssl, const uint8_t *out_data, int out_len)
285
285
{
286
286
int n = out_len , nw , i , tot = 0 ;
287
287
/* maximum size of a TLS packet is around 16kB, so fragment */
288
+
289
+ if (ssl -> can_free_certificates ) {
290
+ certificate_free (ssl );
291
+ }
292
+
288
293
do
289
294
{
290
295
nw = n ;
@@ -545,9 +550,9 @@ SSL *ssl_new(SSL_CTX *ssl_ctx, int client_fd)
545
550
ssl -> flag = SSL_NEED_RECORD ;
546
551
ssl -> bm_data = ssl -> bm_all_data + BM_RECORD_OFFSET ; /* space at the start */
547
552
ssl -> hs_status = SSL_NOT_OK ; /* not connected */
548
- ssl -> can_increase_data_size = false;
549
553
#ifdef CONFIG_ENABLE_VERIFICATION
550
554
ssl -> ca_cert_ctx = ssl_ctx -> ca_cert_ctx ;
555
+ ssl -> can_free_certificates = false;
551
556
#endif
552
557
disposable_new (ssl );
553
558
@@ -1214,6 +1219,10 @@ int basic_read(SSL *ssl, uint8_t **in_data)
1214
1219
int read_len , is_client = IS_SET_SSL_FLAG (SSL_IS_CLIENT );
1215
1220
uint8_t * buf = ssl -> bm_data ;
1216
1221
1222
+ if (ssl -> can_free_certificates ) {
1223
+ certificate_free (ssl );
1224
+ }
1225
+
1217
1226
read_len = SOCKET_READ (ssl -> client_fd , & buf [ssl -> bm_read_index ],
1218
1227
ssl -> need_bytes - ssl -> got_bytes );
1219
1228
@@ -1287,16 +1296,8 @@ int basic_read(SSL *ssl, uint8_t **in_data)
1287
1296
if (ssl -> need_bytes > ssl -> max_plain_length + RT_EXTRA - BM_RECORD_OFFSET )
1288
1297
{
1289
1298
printf ("ssl->need_bytes=%d > %d\r\n" , ssl -> need_bytes , ssl -> max_plain_length + RT_EXTRA - BM_RECORD_OFFSET );
1290
- if (ssl -> can_increase_data_size )
1291
- {
1292
- ret = increase_bm_data_size (ssl );
1293
- if (ret != SSL_OK )
1294
- {
1295
- ret = SSL_ERROR_INVALID_PROT_MSG ;
1296
- goto error ;
1297
- }
1298
- }
1299
- else
1299
+ ret = increase_bm_data_size (ssl , ssl -> need_bytes + BM_RECORD_OFFSET - RT_EXTRA );
1300
+ if (ret != SSL_OK )
1300
1301
{
1301
1302
ret = SSL_ERROR_INVALID_PROT_MSG ;
1302
1303
goto error ;
@@ -1414,24 +1415,22 @@ int basic_read(SSL *ssl, uint8_t **in_data)
1414
1415
return ret ;
1415
1416
}
1416
1417
1417
- int increase_bm_data_size (SSL * ssl )
1418
+ int increase_bm_data_size (SSL * ssl , size_t size )
1418
1419
{
1419
- if (!ssl -> can_increase_data_size ||
1420
- ssl -> max_plain_length == RT_MAX_PLAIN_LENGTH ) {
1420
+ if (ssl -> max_plain_length == RT_MAX_PLAIN_LENGTH ) {
1421
1421
return SSL_OK ;
1422
1422
}
1423
- certificate_free (ssl );
1424
- free (ssl -> bm_all_data );
1425
- ssl -> bm_data = 0 ;
1426
- ssl -> bm_all_data = malloc (RT_MAX_PLAIN_LENGTH + RT_EXTRA );
1427
- if (!ssl -> bm_all_data ) {
1423
+ size_t required = (size + 1023 ) & ~(1023 ); // round up to 1k
1424
+ required = (required < RT_MAX_PLAIN_LENGTH ) ? required : RT_MAX_PLAIN_LENGTH ;
1425
+ uint8_t * new_bm_all_data = (uint8_t * ) realloc (ssl -> bm_all_data , required + RT_EXTRA );
1426
+ if (!new_bm_all_data ) {
1428
1427
printf ("failed to grow plain buffer\r\n" );
1429
1428
ssl -> hs_status = SSL_ERROR_DEAD ;
1430
1429
return SSL_ERROR_CONN_LOST ;
1431
1430
}
1432
- ssl -> can_increase_data_size = false;
1433
- ssl -> max_plain_length = RT_MAX_PLAIN_LENGTH ;
1431
+ ssl -> bm_all_data = new_bm_all_data ;
1434
1432
ssl -> bm_data = ssl -> bm_all_data + BM_RECORD_OFFSET ;
1433
+ ssl -> max_plain_length = required ;
1435
1434
return SSL_OK ;
1436
1435
}
1437
1436
@@ -1689,6 +1688,7 @@ void disposable_free(SSL *ssl)
1689
1688
free (ssl -> dc );
1690
1689
ssl -> dc = NULL ;
1691
1690
}
1691
+ ssl -> can_free_certificates = true;
1692
1692
}
1693
1693
1694
1694
static void certificate_free (SSL * ssl )
@@ -1698,6 +1698,7 @@ static void certificate_free(SSL* ssl)
1698
1698
x509_free (ssl -> x509_ctx );
1699
1699
ssl -> x509_ctx = 0 ;
1700
1700
}
1701
+ ssl -> can_free_certificates = false;
1701
1702
#endif
1702
1703
}
1703
1704
0 commit comments