diff --git a/src/Arduino_ESP32_OTA.cpp b/src/Arduino_ESP32_OTA.cpp index 13a7d92..d500961 100644 --- a/src/Arduino_ESP32_OTA.cpp +++ b/src/Arduino_ESP32_OTA.cpp @@ -26,26 +26,6 @@ #include "decompress/utility.h" #include "esp_ota_ops.h" -/* Used to bind local module function to actual class instance */ -static Arduino_ESP32_OTA * _esp_ota_obj_ptr = 0; - -/****************************************************************************** - LOCAL MODULE FUNCTIONS - ******************************************************************************/ - -static uint8_t read_byte() { - if(_esp_ota_obj_ptr) { - return _esp_ota_obj_ptr->read_byte_from_network(); - } - return -1; -} - -static void write_byte(uint8_t data) { - if(_esp_ota_obj_ptr) { - _esp_ota_obj_ptr->write_byte_to_flash(data); - } -} - /****************************************************************************** CTOR/DTOR ******************************************************************************/ @@ -57,6 +37,7 @@ Arduino_ESP32_OTA::Arduino_ESP32_OTA() ,_crc32(0) ,_ca_cert{amazon_root_ca} ,_ca_cert_bundle{nullptr} +,_magic(0) { } @@ -65,22 +46,22 @@ Arduino_ESP32_OTA::Arduino_ESP32_OTA() PUBLIC MEMBER FUNCTIONS ******************************************************************************/ -Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin() +Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin(uint32_t magic) { - _esp_ota_obj_ptr = this; + /* initialize private variables */ + otaInit(); /* ... initialize CRC ... */ - _crc32 = 0xFFFFFFFF; + crc32Init(); + + /* ... configure board Magic number */ + setMagic(magic); if(!isCapable()) { DEBUG_ERROR("%s: board is not capable to perform OTA", __FUNCTION__); return Error::NoOtaStorage; } - /* initialize private variables */ - _ota_size = 0; - _ota_header = {0}; - if(Update.isRunning()) { Update.abort(); DEBUG_DEBUG("%s: Aborting running update", __FUNCTION__); @@ -107,6 +88,11 @@ void Arduino_ESP32_OTA::setCACertBundle (const uint8_t * bundle) } } +void Arduino_ESP32_OTA::setMagic(uint32_t magic) +{ + _magic = magic; +} + uint8_t Arduino_ESP32_OTA::read_byte_from_network() { bool is_http_data_timeout = false; @@ -119,7 +105,7 @@ uint8_t Arduino_ESP32_OTA::read_byte_from_network() } if (_client->available()) { const uint8_t data = _client->read(); - _crc32 = crc_update(_crc32, &data, 1); + crc32Update(data); return data; } } @@ -262,7 +248,7 @@ int Arduino_ESP32_OTA::download(const char * ota_url) } /* ... and OTA magic number */ - if (_ota_header.header.magic_number != ARDUINO_ESP32_OTA_MAGIC) + if (_ota_header.header.magic_number != _magic) { delete _client; _client = nullptr; @@ -273,7 +259,7 @@ int Arduino_ESP32_OTA::download(const char * ota_url) _crc32 = crc_update(_crc32, &_ota_header.header.magic_number, 12); /* Download and decode OTA file */ - _ota_size = lzss_download(read_byte, write_byte, content_length_val - sizeof(_ota_header)); + _ota_size = lzss_download(this, content_length_val - sizeof(_ota_header)); if(_ota_size <= content_length_val - sizeof(_ota_header)) { @@ -289,10 +275,10 @@ int Arduino_ESP32_OTA::download(const char * ota_url) Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::update() { - /* ... then finalise ... */ - _crc32 ^= 0xFFFFFFFF; + /* ... then finalize ... */ + crc32Finalize(); - if(_crc32 != _ota_header.header.crc32) { + if(!crc32Verify()) { DEBUG_ERROR("%s: CRC32 mismatch", __FUNCTION__); return Error::OtaHeaderCrc; } @@ -316,3 +302,33 @@ bool Arduino_ESP32_OTA::isCapable() const esp_partition_t * ota_1 = esp_partition_find_first(ESP_PARTITION_TYPE_APP, ESP_PARTITION_SUBTYPE_APP_OTA_1, NULL); return ((ota_0 != nullptr) && (ota_1 != nullptr)); } + +/****************************************************************************** + PROTECTED MEMBER FUNCTIONS + ******************************************************************************/ + +void Arduino_ESP32_OTA::otaInit() +{ + _ota_size = 0; + _ota_header = {0}; +} + +void Arduino_ESP32_OTA::crc32Init() +{ + _crc32 = 0xFFFFFFFF; +} + +void Arduino_ESP32_OTA::crc32Update(const uint8_t data) +{ + _crc32 = crc_update(_crc32, &data, 1); +} + +void Arduino_ESP32_OTA::crc32Finalize() +{ + _crc32 ^= 0xFFFFFFFF; +} + +bool Arduino_ESP32_OTA::crc32Verify() +{ + return (_crc32 == _ota_header.header.crc32); +} diff --git a/src/Arduino_ESP32_OTA.h b/src/Arduino_ESP32_OTA.h index 069d22b..34046d8 100644 --- a/src/Arduino_ESP32_OTA.h +++ b/src/Arduino_ESP32_OTA.h @@ -42,13 +42,6 @@ static uint32_t const ARDUINO_ESP32_OTA_HTTP_HEADER_RECEIVE_TIMEOUT_ms = 10000; static uint32_t const ARDUINO_ESP32_OTA_BINARY_HEADER_RECEIVE_TIMEOUT_ms = 10000; static uint32_t const ARDUINO_ESP32_OTA_BINARY_BYTE_RECEIVE_TIMEOUT_ms = 2000; -/****************************************************************************** - * TYPEDEF - ******************************************************************************/ - -typedef uint8_t(*ArduinoEsp32OtaReadByteFuncPointer)(void); -typedef void(*ArduinoEsp32OtaWriteByteFuncPointer)(uint8_t); - /****************************************************************************** * CLASS DECLARATION ******************************************************************************/ @@ -79,24 +72,34 @@ class Arduino_ESP32_OTA Arduino_ESP32_OTA(); virtual ~Arduino_ESP32_OTA() { } - Arduino_ESP32_OTA::Error begin(); - void setCACert (const char *rootCA); + Arduino_ESP32_OTA::Error begin(uint32_t magic = ARDUINO_ESP32_OTA_MAGIC); + void setMagic(uint32_t magic); + void setCACert(const char *rootCA); void setCACertBundle(const uint8_t * bundle); int download(const char * ota_url); uint8_t read_byte_from_network(); - void write_byte_to_flash(uint8_t data); + virtual void write_byte_to_flash(uint8_t data); Arduino_ESP32_OTA::Error update(); void reset(); static bool isCapable(); -private: +protected: + + void otaInit(); + void crc32Init(); + void crc32Update(const uint8_t data); + void crc32Finalize(); + bool crc32Verify(); +private: Client * _client; OtaHeader _ota_header; size_t _ota_size; uint32_t _crc32; const char * _ca_cert; const uint8_t * _ca_cert_bundle; + uint32_t _magic; + }; #endif /* ARDUINO_ESP32_OTA_H_ */ diff --git a/src/decompress/lzss.cpp b/src/decompress/lzss.cpp index bce0e1c..d453c15 100644 --- a/src/decompress/lzss.cpp +++ b/src/decompress/lzss.cpp @@ -22,9 +22,10 @@ GLOBAL VARIABLES **************************************************************************************/ +/* Used to bind local module function to actual class instance */ +static Arduino_ESP32_OTA * esp_ota_obj_ptr = 0; + static size_t LZSS_FILE_SIZE = 0; -static ArduinoEsp32OtaReadByteFuncPointer read_byte_fptr = 0; -static ArduinoEsp32OtaWriteByteFuncPointer write_byte_fptr = 0; int bit_buffer = 0, bit_mask = 128; unsigned char buffer[N * 2]; @@ -38,7 +39,7 @@ static size_t bytes_read_fgetc = 0; void lzss_fputc(int const c) { - write_byte_fptr((uint8_t)c); + esp_ota_obj_ptr->write_byte_to_flash((uint8_t)c); /* write byte callback */ bytes_written_fputc++; @@ -56,7 +57,7 @@ int lzss_fgetc() return LZSS_EOF; /* read byte callback */ - uint8_t const c = read_byte_fptr(); + uint8_t const c = esp_ota_obj_ptr->read_byte_from_network(); bytes_read_fgetc++; return c; @@ -157,10 +158,9 @@ void lzss_decode(void) PUBLIC FUNCTIONS **************************************************************************************/ -int lzss_download(ArduinoEsp32OtaReadByteFuncPointer read_byte, ArduinoEsp32OtaWriteByteFuncPointer write_byte, size_t const lzss_file_size) +int lzss_download(Arduino_ESP32_OTA * instance, size_t const lzss_file_size) { - read_byte_fptr = read_byte; - write_byte_fptr = write_byte; + esp_ota_obj_ptr = instance; LZSS_FILE_SIZE = lzss_file_size; bytes_written_fputc = 0; bytes_read_fgetc = 0; diff --git a/src/decompress/lzss.h b/src/decompress/lzss.h index 3ddc3f2..e88cf20 100644 --- a/src/decompress/lzss.h +++ b/src/decompress/lzss.h @@ -11,6 +11,6 @@ FUNCTION DEFINITION **************************************************************************************/ -int lzss_download(ArduinoEsp32OtaReadByteFuncPointer read_byte, ArduinoEsp32OtaWriteByteFuncPointer write_byte, size_t const lzss_file_size); +int lzss_download(Arduino_ESP32_OTA * instance, size_t const lzss_file_size); #endif /* SSU_LZSS_H_ */