Skip to content

Allow subclassing #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 49 additions & 33 deletions src/Arduino_ESP32_OTA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,6 @@
#include "decompress/utility.h"
#include "esp_ota_ops.h"

/* Used to bind local module function to actual class instance */
static Arduino_ESP32_OTA * _esp_ota_obj_ptr = 0;

/******************************************************************************
LOCAL MODULE FUNCTIONS
******************************************************************************/

static uint8_t read_byte() {
if(_esp_ota_obj_ptr) {
return _esp_ota_obj_ptr->read_byte_from_network();
}
return -1;
}

static void write_byte(uint8_t data) {
if(_esp_ota_obj_ptr) {
_esp_ota_obj_ptr->write_byte_to_flash(data);
}
}

/******************************************************************************
CTOR/DTOR
******************************************************************************/
Expand All @@ -57,6 +37,7 @@ Arduino_ESP32_OTA::Arduino_ESP32_OTA()
,_crc32(0)
,_ca_cert{amazon_root_ca}
,_ca_cert_bundle{nullptr}
,_magic(0)
{

}
Expand All @@ -65,22 +46,22 @@ Arduino_ESP32_OTA::Arduino_ESP32_OTA()
PUBLIC MEMBER FUNCTIONS
******************************************************************************/

Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin()
Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin(uint32_t magic)
{
_esp_ota_obj_ptr = this;
/* initialize private variables */
otaInit();

/* ... initialize CRC ... */
_crc32 = 0xFFFFFFFF;
crc32Init();

/* ... configure board Magic number */
setMagic(magic);

if(!isCapable()) {
DEBUG_ERROR("%s: board is not capable to perform OTA", __FUNCTION__);
return Error::NoOtaStorage;
}

/* initialize private variables */
_ota_size = 0;
_ota_header = {0};

if(Update.isRunning()) {
Update.abort();
DEBUG_DEBUG("%s: Aborting running update", __FUNCTION__);
Expand All @@ -107,6 +88,11 @@ void Arduino_ESP32_OTA::setCACertBundle (const uint8_t * bundle)
}
}

void Arduino_ESP32_OTA::setMagic(uint32_t magic)
{
_magic = magic;
}

uint8_t Arduino_ESP32_OTA::read_byte_from_network()
{
bool is_http_data_timeout = false;
Expand All @@ -119,7 +105,7 @@ uint8_t Arduino_ESP32_OTA::read_byte_from_network()
}
if (_client->available()) {
const uint8_t data = _client->read();
_crc32 = crc_update(_crc32, &data, 1);
crc32Update(data);
return data;
}
}
Expand Down Expand Up @@ -262,7 +248,7 @@ int Arduino_ESP32_OTA::download(const char * ota_url)
}

/* ... and OTA magic number */
if (_ota_header.header.magic_number != ARDUINO_ESP32_OTA_MAGIC)
if (_ota_header.header.magic_number != _magic)
{
delete _client;
_client = nullptr;
Expand All @@ -273,7 +259,7 @@ int Arduino_ESP32_OTA::download(const char * ota_url)
_crc32 = crc_update(_crc32, &_ota_header.header.magic_number, 12);

/* Download and decode OTA file */
_ota_size = lzss_download(read_byte, write_byte, content_length_val - sizeof(_ota_header));
_ota_size = lzss_download(this, content_length_val - sizeof(_ota_header));

if(_ota_size <= content_length_val - sizeof(_ota_header))
{
Expand All @@ -289,10 +275,10 @@ int Arduino_ESP32_OTA::download(const char * ota_url)

Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::update()
{
/* ... then finalise ... */
_crc32 ^= 0xFFFFFFFF;
/* ... then finalize ... */
crc32Finalize();

if(_crc32 != _ota_header.header.crc32) {
if(!crc32Verify()) {
DEBUG_ERROR("%s: CRC32 mismatch", __FUNCTION__);
return Error::OtaHeaderCrc;
}
Expand All @@ -316,3 +302,33 @@ bool Arduino_ESP32_OTA::isCapable()
const esp_partition_t * ota_1 = esp_partition_find_first(ESP_PARTITION_TYPE_APP, ESP_PARTITION_SUBTYPE_APP_OTA_1, NULL);
return ((ota_0 != nullptr) && (ota_1 != nullptr));
}

/******************************************************************************
PROTECTED MEMBER FUNCTIONS
******************************************************************************/

void Arduino_ESP32_OTA::otaInit()
{
_ota_size = 0;
_ota_header = {0};
}

void Arduino_ESP32_OTA::crc32Init()
{
_crc32 = 0xFFFFFFFF;
}

void Arduino_ESP32_OTA::crc32Update(const uint8_t data)
{
_crc32 = crc_update(_crc32, &data, 1);
}

void Arduino_ESP32_OTA::crc32Finalize()
{
_crc32 ^= 0xFFFFFFFF;
}

bool Arduino_ESP32_OTA::crc32Verify()
{
return (_crc32 == _ota_header.header.crc32);
}
25 changes: 14 additions & 11 deletions src/Arduino_ESP32_OTA.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@ static uint32_t const ARDUINO_ESP32_OTA_HTTP_HEADER_RECEIVE_TIMEOUT_ms = 10000;
static uint32_t const ARDUINO_ESP32_OTA_BINARY_HEADER_RECEIVE_TIMEOUT_ms = 10000;
static uint32_t const ARDUINO_ESP32_OTA_BINARY_BYTE_RECEIVE_TIMEOUT_ms = 2000;

/******************************************************************************
* TYPEDEF
******************************************************************************/

typedef uint8_t(*ArduinoEsp32OtaReadByteFuncPointer)(void);
typedef void(*ArduinoEsp32OtaWriteByteFuncPointer)(uint8_t);

/******************************************************************************
* CLASS DECLARATION
******************************************************************************/
Expand Down Expand Up @@ -79,24 +72,34 @@ class Arduino_ESP32_OTA
Arduino_ESP32_OTA();
virtual ~Arduino_ESP32_OTA() { }

Arduino_ESP32_OTA::Error begin();
void setCACert (const char *rootCA);
Arduino_ESP32_OTA::Error begin(uint32_t magic = ARDUINO_ESP32_OTA_MAGIC);
void setMagic(uint32_t magic);
void setCACert(const char *rootCA);
void setCACertBundle(const uint8_t * bundle);
int download(const char * ota_url);
uint8_t read_byte_from_network();
void write_byte_to_flash(uint8_t data);
virtual void write_byte_to_flash(uint8_t data);
Arduino_ESP32_OTA::Error update();
void reset();
static bool isCapable();

private:
protected:

void otaInit();
void crc32Init();
void crc32Update(const uint8_t data);
void crc32Finalize();
bool crc32Verify();

private:
Client * _client;
OtaHeader _ota_header;
size_t _ota_size;
uint32_t _crc32;
const char * _ca_cert;
const uint8_t * _ca_cert_bundle;
uint32_t _magic;

};

#endif /* ARDUINO_ESP32_OTA_H_ */
14 changes: 7 additions & 7 deletions src/decompress/lzss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
GLOBAL VARIABLES
**************************************************************************************/

/* Used to bind local module function to actual class instance */
static Arduino_ESP32_OTA * esp_ota_obj_ptr = 0;

static size_t LZSS_FILE_SIZE = 0;
static ArduinoEsp32OtaReadByteFuncPointer read_byte_fptr = 0;
static ArduinoEsp32OtaWriteByteFuncPointer write_byte_fptr = 0;

int bit_buffer = 0, bit_mask = 128;
unsigned char buffer[N * 2];
Expand All @@ -38,7 +39,7 @@ static size_t bytes_read_fgetc = 0;

void lzss_fputc(int const c)
{
write_byte_fptr((uint8_t)c);
esp_ota_obj_ptr->write_byte_to_flash((uint8_t)c);

/* write byte callback */
bytes_written_fputc++;
Expand All @@ -56,7 +57,7 @@ int lzss_fgetc()
return LZSS_EOF;

/* read byte callback */
uint8_t const c = read_byte_fptr();
uint8_t const c = esp_ota_obj_ptr->read_byte_from_network();
bytes_read_fgetc++;

return c;
Expand Down Expand Up @@ -157,10 +158,9 @@ void lzss_decode(void)
PUBLIC FUNCTIONS
**************************************************************************************/

int lzss_download(ArduinoEsp32OtaReadByteFuncPointer read_byte, ArduinoEsp32OtaWriteByteFuncPointer write_byte, size_t const lzss_file_size)
int lzss_download(Arduino_ESP32_OTA * instance, size_t const lzss_file_size)
{
read_byte_fptr = read_byte;
write_byte_fptr = write_byte;
esp_ota_obj_ptr = instance;
LZSS_FILE_SIZE = lzss_file_size;
bytes_written_fputc = 0;
bytes_read_fgetc = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/decompress/lzss.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
FUNCTION DEFINITION
**************************************************************************************/

int lzss_download(ArduinoEsp32OtaReadByteFuncPointer read_byte, ArduinoEsp32OtaWriteByteFuncPointer write_byte, size_t const lzss_file_size);
int lzss_download(Arduino_ESP32_OTA * instance, size_t const lzss_file_size);

#endif /* SSU_LZSS_H_ */