-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathArduino_ESP32_OTA.cpp
398 lines (329 loc) · 10.8 KB
/
Arduino_ESP32_OTA.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
/*
This file is part of Arduino_ESP32_OTA.
Copyright 2022 ARDUINO SA (http://www.arduino.cc/)
This software is released under the GNU General Public License version 3,
which covers the main part of arduino-cli.
The terms of this license can be found at:
https://www.gnu.org/licenses/gpl-3.0.en.html
You can be released from the requirements of the above licenses by purchasing
a commercial license. Buying such a license is mandatory if you want to modify or
otherwise use the software for commercial activities involving the Arduino
software without disclosing the source code of your own applications. To purchase
a commercial license, send an email to [email protected].
*/
/******************************************************************************
INCLUDE
******************************************************************************/
#include <Update.h>
#include "Arduino_ESP32_OTA.h"
#include "tls/amazon_root_ca.h"
#include "esp_ota_ops.h"
/******************************************************************************
CTOR/DTOR
******************************************************************************/
Arduino_ESP32_OTA::Arduino_ESP32_OTA()
: _context(nullptr)
, _client(nullptr)
, _http_client(nullptr)
,_ca_cert{amazon_root_ca}
,_ca_cert_bundle{nullptr}
,_ca_cert_bundle_size(0)
,_magic(0)
{
}
/* Release whatever download context and network clients are still held. */
Arduino_ESP32_OTA::~Arduino_ESP32_OTA()
{
  clean();
}
/******************************************************************************
PUBLIC MEMBER FUNCTIONS
******************************************************************************/
/**
 * Prepare the OTA engine for receiving a new image.
 * @param magic board magic number the downloaded OTA header must carry
 * @return Error::None on success,
 *         Error::NoOtaStorage if the partition table lacks the OTA slots,
 *         Error::OtaStorageInit if the flash update could not be started.
 */
Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin(uint32_t magic)
{
  setMagic(magic);

  const bool capable = isCapable();
  if(!capable) {
    DEBUG_ERROR("%s: board is not capable to perform OTA", __FUNCTION__);
    return Error::NoOtaStorage;
  }

  /* Discard any update still in progress before starting a new one. */
  if(Update.isRunning()) {
    Update.abort();
    DEBUG_DEBUG("%s: Aborting running update", __FUNCTION__);
  }

  const bool started = Update.begin(UPDATE_SIZE_UNKNOWN);
  if(started) {
    return Error::None;
  }

  DEBUG_ERROR("%s: failed to initialize flash update", __FUNCTION__);
  return Error::OtaStorageInit;
}
/**
 * Select the root CA certificate used for https downloads.
 * @param rootCA PEM certificate; nullptr is ignored and the current one kept.
 */
void Arduino_ESP32_OTA::setCACert (const char *rootCA)
{
  if(rootCA == nullptr) {
    return;
  }
  _ca_cert = rootCA;
}
/**
 * Select a CA certificate bundle (without size) for https downloads.
 * @param bundle certificate bundle; nullptr is ignored.
 *
 * NOTE(review): startDownload() checks _ca_cert first, and _ca_cert is
 * initialized to amazon_root_ca, so this bundle is only consulted when no
 * certificate is set. On cores >= 3.0.4 startDownload() also requires a
 * non-zero bundle size, which this overload does not provide — confirm intent.
 */
void Arduino_ESP32_OTA::setCACertBundle (const uint8_t * bundle)
{
  if(bundle == nullptr) {
    return;
  }
  _ca_cert_bundle = bundle;
}
void Arduino_ESP32_OTA::setCACertBundle (const uint8_t * bundle, size_t size)
{
if(bundle != nullptr && size != 0) {
_ca_cert_bundle = bundle;
_ca_cert_bundle_size = size;
}
}
void Arduino_ESP32_OTA::setMagic(uint32_t magic)
{
_magic = magic;
}
/* Append one byte to the running flash update (the write result is ignored). */
void Arduino_ESP32_OTA::write_byte_to_flash(uint8_t data)
{
  uint8_t byte = data;
  Update.write(&byte, 1);
}
int Arduino_ESP32_OTA::startDownload(const char * ota_url)
{
assert(_context == nullptr);
assert(_client == nullptr);
assert(_http_client == nullptr);
Error err = Error::None;
int statusCode;
int res;
_context = new Context(ota_url, [this](uint8_t data){
_context->writtenBytes++;
write_byte_to_flash(data);
});
if(strcmp(_context->parsed_url.schema(), "http") == 0) {
_client = new WiFiClient();
} else if(strcmp(_context->parsed_url.schema(), "https") == 0) {
_client = new WiFiClientSecure();
if (_ca_cert != nullptr) {
static_cast<WiFiClientSecure*>(_client)->setCACert(_ca_cert);
}
#if (ESP_ARDUINO_VERSION < ESP_ARDUINO_VERSION_VAL(3, 0, 4))
else if (_ca_cert_bundle != nullptr) {
static_cast<WiFiClientSecure*>(_client)->setCACertBundle(_ca_cert_bundle);
}
#else
else if (_ca_cert_bundle != nullptr && _ca_cert_bundle_size != 0) {
static_cast<WiFiClientSecure*>(_client)->setCACertBundle(_ca_cert_bundle, _ca_cert_bundle_size);
}
#endif
else {
DEBUG_VERBOSE("%s: CA not configured for download client");
}
} else {
err = Error::UrlParseError;
goto exit;
}
_http_client = new HttpClient(*_client, _context->parsed_url.host(), _context->parsed_url.port());
res= _http_client->get(_context->parsed_url.path());
if(res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
_context->parsed_url.host(), _context->parsed_url.port());
err = Error::ServerConnectError;
goto exit;
} else if(res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", _context->url);
err = Error::OtaHeaderTimeout;
goto exit;
} else if(res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, _context->url);
err = Error::OtaDownload;
goto exit;
}
statusCode = _http_client->responseStatusCode();
if(statusCode != 200) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", _context->url, statusCode);
err = Error::HttpResponse;
goto exit;
}
// The following call is required to save the header value , keep it
if(_http_client->contentLength() == HttpClient::kNoContentLengthHeader) {
DEBUG_VERBOSE("OTA ERROR: the response header doesn't contain \"ContentLength\" field");
err = Error::HttpHeaderError;
goto exit;
}
exit:
if(err != Error::None) {
clean();
return static_cast<int>(err);
} else {
return _http_client->contentLength();
}
}
int Arduino_ESP32_OTA::downloadPoll()
{
int http_res = static_cast<int>(Error::None);;
int res = 0;
if(_http_client->available() == 0) {
goto exit;
}
http_res = _http_client->read(_context->buffer, _context->buf_len);
if(http_res < 0) {
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
res = static_cast<int>(Error::OtaDownload);
goto exit;
}
for(uint8_t* cursor=(uint8_t*)_context->buffer; cursor<_context->buffer+http_res; ) {
switch(_context->downloadState) {
case OtaDownloadHeader: {
uint32_t copied = http_res < sizeof(_context->header.buf) ? http_res : sizeof(_context->header.buf);
memcpy(_context->header.buf+_context->headerCopiedBytes, _context->buffer, copied);
cursor += copied;
_context->headerCopiedBytes += copied;
// when finished go to next state
if(sizeof(_context->header.buf) == _context->headerCopiedBytes) {
_context->downloadState = OtaDownloadFile;
_context->calculatedCrc32 = crc_update(
_context->calculatedCrc32,
&(_context->header.header.magic_number),
sizeof(_context->header) - offsetof(OtaHeader, header.magic_number)
);
if(_context->header.header.magic_number != _magic) {
_context->downloadState = OtaDownloadMagicNumberMismatch;
res = static_cast<int>(Error::OtaHeaderMagicNumber);
goto exit;
}
}
break;
}
case OtaDownloadFile:
_context->decoder.decompress(cursor, http_res - (cursor-_context->buffer)); // TODO verify return value
_context->calculatedCrc32 = crc_update(
_context->calculatedCrc32,
cursor,
http_res - (cursor-_context->buffer)
);
cursor += http_res - (cursor-_context->buffer);
_context->downloadedSize += (cursor-_context->buffer);
// TODO there should be no more bytes available when the download is completed
if(_context->downloadedSize == _http_client->contentLength()) {
_context->downloadState = OtaDownloadCompleted;
res = 1;
}
if(_context->downloadedSize > _http_client->contentLength()) {
_context->downloadState = OtaDownloadError;
res = static_cast<int>(Error::OtaDownload);
}
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
break;
case OtaDownloadCompleted:
res = 1;
goto exit;
default:
_context->downloadState = OtaDownloadError;
res = static_cast<int>(Error::OtaDownload);
goto exit;
}
}
exit:
if(_context->downloadState == OtaDownloadError ||
_context->downloadState == OtaDownloadMagicNumberMismatch) {
clean(); // need to clean everything because the download failed
} else if(_context->downloadState == OtaDownloadCompleted) {
// only need to delete clients and not the context, since it will be needed
if(_client != nullptr) {
delete _client;
_client = nullptr;
}
if(_http_client != nullptr) {
delete _http_client;
_http_client = nullptr;
}
}
return res;
}
int Arduino_ESP32_OTA::downloadProgress()
{
if(_context->error != Error::None) {
return static_cast<int>(_context->error);
} else {
return _context->downloadedSize;
}
}
/* Content-Length of the current response, or 0 when no HTTP client exists. */
size_t Arduino_ESP32_OTA::downloadSize()
{
  if(_http_client == nullptr) {
    return 0;
  }
  return _http_client->contentLength();
}
/**
 * Blocking download: startDownload() followed by downloadPoll() until done.
 * @param ota_url http:// or https:// URL of the OTA file
 * @return number of bytes written to flash on success, otherwise the negative
 *         error returned by startDownload() or downloadPoll().
 */
int Arduino_ESP32_OTA::download(const char * ota_url)
{
  int err = startDownload(ota_url);

  if(err < 0) {
    return err;
  }

  /* downloadPoll() returns 0 while in progress, 1 on completion and a negative
   * error on failure. Stop on errors as well: a failed poll may already have
   * called clean(), so polling again would dereference released state. */
  int res = 0;
  while((res = downloadPoll()) == 0) {
    /* keep polling */
  }

  return res == 1 ? _context->writtenBytes : res;
}
/**
 * Release the HTTP client, the network client and the download context.
 * Safe to call repeatedly: every pointer is reset to nullptr.
 */
void Arduino_ESP32_OTA::clean()
{
  /* _http_client holds a reference to *_client (see startDownload), so it is
   * destroyed first, in reverse order of construction. delete on nullptr is a
   * no-op, so no explicit checks are needed. */
  delete _http_client;
  _http_client = nullptr;

  delete _client;
  _client = nullptr;

  delete _context;
  _context = nullptr;
}
/**
 * Finalize the running CRC32 and compare it with the one in the OTA header.
 * @return Error::None (and all resources released) when the CRC matches,
 *         Error::OtaHeaderCrc otherwise (the context is kept in that case).
 * Note: the finalization XOR mutates the context, so calling this twice on
 * the same context does not give the same CRC.
 */
Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::verify()
{
  assert(_context != nullptr);

  /* Final XOR of the CRC32, which was seeded with 0xFFFFFFFF. */
  _context->calculatedCrc32 ^= 0xFFFFFFFF;

  const bool crc_matches = (_context->header.header.crc32 == _context->calculatedCrc32);
  if(!crc_matches) {
    DEBUG_ERROR("%s: CRC32 mismatch", __FUNCTION__);
    return Error::OtaHeaderCrc;
  }

  clean();
  return Error::None;
}
/**
 * Verify the downloaded image (when a context is held) and commit it.
 * @return Error::None on success, the verify() error, or
 *         Error::OtaStorageEnd if the flash update cannot be finalized.
 *
 * NOTE(review): when _context is nullptr the image is committed with
 * Update.end(true) without any CRC check — confirm this is intended.
 */
Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::update()
{
  if(_context != nullptr) {
    const Error verified = verify();
    if(verified != Error::None) {
      return verified;
    }
  }

  if(Update.end(true)) {
    return Error::None;
  }

  DEBUG_ERROR("%s: Failure to apply OTA update", __FUNCTION__);
  return Error::OtaStorageEnd;
}
/* Reboot the board through ESP.restart(). */
void Arduino_ESP32_OTA::reset()
{
  ESP.restart();
}
bool Arduino_ESP32_OTA::isCapable()
{
const esp_partition_t * ota_0 = esp_partition_find_first(ESP_PARTITION_TYPE_APP, ESP_PARTITION_SUBTYPE_APP_OTA_0, NULL);
const esp_partition_t * ota_1 = esp_partition_find_first(ESP_PARTITION_TYPE_APP, ESP_PARTITION_SUBTYPE_APP_OTA_1, NULL);
return ((ota_0 != nullptr) && (ota_1 != nullptr));
}
/******************************************************************************
PROTECTED MEMBER FUNCTIONS
******************************************************************************/
/**
 * Download context: keeps a private heap copy of the URL, its parsed form,
 * the state machine position, CRC/progress counters and the decoder that
 * forwards output bytes to putc.
 * @param url  OTA file URL (copied; the caller keeps ownership of its string)
 * @param putc byte sink handed to the decoder
 */
Arduino_ESP32_OTA::Context::Context(
  const char* url, std::function<void(uint8_t)> putc)
  : url(static_cast<char*>(malloc(strlen(url) + 1)))
  , parsed_url(url)
  , downloadState(OtaDownloadHeader)
  , calculatedCrc32(0xFFFFFFFF)
  , headerCopiedBytes(0)
  , downloadedSize(0)
  , writtenBytes(0)
  , error(Error::None)
  , decoder(putc)
{
  /* Copy the URL including its terminating NUL into the buffer above. */
  memcpy(this->url, url, strlen(url) + 1);
}
/* Release the URL copy allocated by the constructor. */
Arduino_ESP32_OTA::Context::~Context()
{
  free(this->url);
  this->url = nullptr;
}