Skip to content

Commit a5dcd3d

Browse files
authored
Merge pull request #520 from pennam/main-ota-chunked
OTA: chunked download
2 parents 07da25e + 0f53459 commit a5dcd3d

File tree

5 files changed

+122
-50
lines changed

5 files changed

+122
-50
lines changed

src/ArduinoIoTCloudTCP.h

+11-2
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,18 @@ class ArduinoIoTCloudTCP: public ArduinoIoTCloudClass
9696
_get_ota_confirmation = cb;
9797

9898
if(_get_ota_confirmation) {
99-
_ota.setOtaPolicies(OTACloudProcessInterface::ApprovalRequired);
99+
_ota.enableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
100100
} else {
101-
_ota.setOtaPolicies(OTACloudProcessInterface::None);
101+
_ota.disableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
102+
}
103+
}
104+
105+
/* Slower but more reliable in some corner cases */
106+
void setOTAChunkMode(bool enable = true) {
107+
if(enable) {
108+
_ota.enableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
109+
} else {
110+
_ota.disableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
102111
}
103112
}
104113
#endif

src/ota/interface/OTAInterface.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,10 @@ OTACloudProcessInterface::State OTACloudProcessInterface::idle(Message* msg) {
167167
OTACloudProcessInterface::State OTACloudProcessInterface::otaAvailable() {
168168
// depending on the policy decided on this device the ota process can start immediately
169169
// or wait for confirmation from the user
170-
if((policies & (ApprovalRequired | Approved)) == ApprovalRequired ) {
170+
if(getOtaPolicy(ApprovalRequired) && !getOtaPolicy(Approved)) {
171171
return OtaAvailable;
172172
} else {
173-
policies &= ~Approved;
173+
disableOtaPolicy(Approved);
174174
return StartOTA;
175175
} // TODO add an abortOTA command? in this case delete the context
176176
}

src/ota/interface/OTAInterface.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,22 @@ class OTACloudProcessInterface: public CloudProcess {
8080
enum OtaFlags: uint16_t {
8181
None = 0,
8282
ApprovalRequired = 1,
83-
Approved = 1<<1
83+
Approved = 1<<1,
84+
ChunkDownload = 1<<2
8485
};
8586

8687
virtual void handleMessage(Message*);
8788
// virtual CloudProcess::State getState();
8889
// virtual void hook(State s, void* action);
8990
virtual void update() { handleMessage(nullptr); }
9091

91-
inline void approveOta() { policies |= Approved; }
92+
inline void approveOta() { this->policies |= Approved; }
9293
inline void setOtaPolicies(uint16_t policies) { this->policies = policies; }
9394

95+
inline void enableOtaPolicy(OtaFlags policyFlag) { this->policies |= policyFlag; }
96+
inline void disableOtaPolicy(OtaFlags policyFlag) { this->policies &= ~policyFlag; }
97+
inline bool getOtaPolicy(OtaFlags policyFlag) { return (this->policies & policyFlag) != 0;}
98+
9499
inline State getState() { return state; }
95100

96101
virtual bool isOtaCapable() = 0;

src/ota/interface/OTAInterfaceDefault.cpp

+89-41
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,17 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
4141
}
4242
);
4343

44-
// make the http get request
44+
// check url
4545
if(strcmp(context->parsed_url.schema(), "https") == 0) {
4646
http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port());
4747
} else {
4848
return UrlParseErrorFail;
4949
}
5050

51-
http_client->beginRequest();
52-
auto res = http_client->get(context->parsed_url.path());
53-
54-
if(username != nullptr && password != nullptr) {
55-
http_client->sendBasicAuth(username, password);
56-
}
57-
58-
http_client->endRequest();
59-
60-
if(res == HTTP_ERROR_CONNECTION_FAILED) {
61-
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
62-
context->parsed_url.host(), context->parsed_url.port());
63-
return ServerConnectErrorFail;
64-
} else if(res == HTTP_ERROR_TIMED_OUT) {
65-
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
66-
return OtaHeaderTimeoutFail;
67-
} else if(res != HTTP_SUCCESS) {
68-
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url);
69-
return OtaDownloadFail;
70-
}
71-
72-
int statusCode = http_client->responseStatusCode();
73-
74-
if(statusCode != 200) {
75-
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
76-
return HttpResponseFail;
51+
// make the http get request
52+
OTACloudProcessInterface::State res = requestOta();
53+
if(res != Fetch) {
54+
return res;
7755
}
7856

7957
// The following call is required to save the header value , keep it
@@ -82,16 +60,27 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
8260
return HttpHeaderErrorFail;
8361
}
8462

63+
context->contentLength = http_client->contentLength();
8564
context->lastReportTime = millis();
86-
65+
DEBUG_VERBOSE("OTA file length: %d", context->contentLength);
8766
return Fetch;
8867
}
8968

9069
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
9170
OTACloudProcessInterface::State res = Fetch;
92-
int http_res = 0;
93-
uint32_t start = millis();
9471

72+
if(getOtaPolicy(ChunkDownload)) {
73+
res = requestOta(ChunkDownload);
74+
}
75+
76+
context->downloadedChunkSize = 0;
77+
context->downloadedChunkStartTime = millis();
78+
79+
if(res != Fetch) {
80+
goto exit;
81+
}
82+
83+
/* download chunked or timed */
9584
do {
9685
if(!http_client->connected()) {
9786
res = OtaDownloadFail;
@@ -104,7 +93,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
10493
continue;
10594
}
10695

107-
http_res = http_client->read(context->buffer, context->buf_len);
96+
int http_res = http_client->read(context->buffer, context->bufLen);
10897

10998
if(http_res < 0) {
11099
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
@@ -119,8 +108,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
119108
res = ErrorWriteUpdateFileFail;
120109
goto exit;
121110
}
122-
} while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) &&
123-
millis() - start < downloadTime);
111+
112+
context->downloadedChunkSize += http_res;
113+
114+
} while(context->downloadState < OtaDownloadCompleted && fetchMore());
124115

125116
// TODO verify that the information present in the ota header match the info in context
126117
if(context->downloadState == OtaDownloadCompleted) {
@@ -153,13 +144,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
153144
return res;
154145
}
155146

156-
void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) {
147+
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OtaFlags mode) {
148+
int http_res = 0;
149+
150+
/* stop connected client */
151+
http_client->stop();
152+
153+
/* request chunk */
154+
http_client->beginRequest();
155+
http_res = http_client->get(context->parsed_url.path());
156+
157+
if(username != nullptr && password != nullptr) {
158+
http_client->sendBasicAuth(username, password);
159+
}
160+
161+
if((mode & ChunkDownload) == ChunkDownload) {
162+
char range[128] = {0};
163+
size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize;
164+
sprintf(range, "bytes=%" PRIu32 "-%" PRIu32, context->downloadedSize, context->downloadedSize + rangeSize);
165+
DEBUG_VERBOSE("OTA downloading range: %s", range);
166+
http_client->sendHeader("Range", range);
167+
}
168+
169+
http_client->endRequest();
170+
171+
if(http_res == HTTP_ERROR_CONNECTION_FAILED) {
172+
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
173+
context->parsed_url.host(), context->parsed_url.port());
174+
return ServerConnectErrorFail;
175+
} else if(http_res == HTTP_ERROR_TIMED_OUT) {
176+
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
177+
return OtaHeaderTimeoutFail;
178+
} else if(http_res != HTTP_SUCCESS) {
179+
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url);
180+
return OtaDownloadFail;
181+
}
182+
183+
int statusCode = http_client->responseStatusCode();
184+
185+
if((((mode & ChunkDownload) == ChunkDownload) && (statusCode != 206)) ||
186+
(((mode & ChunkDownload) != ChunkDownload) && (statusCode != 200))) {
187+
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
188+
return HttpResponseFail;
189+
}
190+
191+
http_client->skipResponseHeaders();
192+
return Fetch;
193+
}
194+
195+
bool OTADefaultCloudProcessInterface::fetchMore() {
196+
if (getOtaPolicy(ChunkDownload)) {
197+
return context->downloadedChunkSize < maxChunkSize;
198+
} else {
199+
return (millis() - context->downloadedChunkStartTime) < downloadTime;
200+
}
201+
}
202+
203+
void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) {
157204
assert(context != nullptr); // This should never fail
158205

159-
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+buf_len; ) {
206+
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+bufLen; ) {
160207
switch(context->downloadState) {
161208
case OtaDownloadHeader: {
162-
const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes;
209+
const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes;
163210
memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft);
164211
cursor += headerLeft;
165212
context->headerCopiedBytes += headerLeft;
@@ -184,8 +231,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
184231
break;
185232
}
186233
case OtaDownloadFile: {
187-
const uint32_t contentLength = http_client->contentLength();
188-
const uint32_t dataLeft = buf_len - (cursor-buffer);
234+
const uint32_t dataLeft = bufLen - (cursor-buffer);
189235
context->decoder.decompress(cursor, dataLeft); // TODO verify return value
190236

191237
context->calculatedCrc32 = crc_update(
@@ -198,18 +244,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
198244
context->downloadedSize += dataLeft;
199245

200246
if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond
201-
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength);
247+
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength);
202248

203249
reportStatus(context->downloadedSize);
204250
context->lastReportTime = millis();
205251
}
206252

207253
// TODO there should be no more bytes available when the download is completed
208-
if(context->downloadedSize == contentLength) {
254+
if(context->downloadedSize == context->contentLength) {
209255
context->downloadState = OtaDownloadCompleted;
210256
}
211257

212-
if(context->downloadedSize > contentLength) {
258+
if(context->downloadedSize > context->contentLength) {
213259
context->downloadState = OtaDownloadError;
214260
}
215261
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
@@ -250,7 +296,9 @@ OTADefaultCloudProcessInterface::Context::Context(
250296
, headerCopiedBytes(0)
251297
, downloadedSize(0)
252298
, lastReportTime(0)
299+
, contentLength(0)
253300
, writeError(false)
301+
, downloadedChunkSize(0)
254302
, decoder(putc) { }
255303

256304
static const uint32_t crc_table[256] = {

src/ota/interface/OTAInterfaceDefault.h

+13-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
4242
virtual int writeFlash(uint8_t* const buffer, size_t len) = 0;
4343

4444
private:
45-
void parseOta(uint8_t* buffer, size_t buf_len);
45+
void parseOta(uint8_t* buffer, size_t bufLen);
46+
State requestOta(OtaFlags mode = None);
47+
bool fetchMore();
4648

4749
Client* client;
4850
HttpClient* http_client;
@@ -53,6 +55,10 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
5355
// This mitigate the issues arising from tasks run in main loop that are using all the computing time
5456
static constexpr uint32_t downloadTime = 2000;
5557

58+
// The amount of data that each iteration of Fetch has to take at least
59+
// This should be enabled setting ChunkDownload OtaFlag to 1 and mitigate some Ota corner cases
60+
static constexpr size_t maxChunkSize = 1024 * 10;
61+
5662
enum OTADownloadState: uint8_t {
5763
OtaDownloadHeader,
5864
OtaDownloadFile,
@@ -74,13 +80,17 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
7480
uint32_t headerCopiedBytes;
7581
uint32_t downloadedSize;
7682
uint32_t lastReportTime;
83+
uint32_t contentLength;
7784
bool writeError;
7885

86+
uint32_t downloadedChunkStartTime;
87+
uint32_t downloadedChunkSize;
88+
7989
// LZSS decoder
8090
LZSSDecoder decoder;
8191

82-
const size_t buf_len = 64;
83-
uint8_t buffer[64];
92+
static constexpr size_t bufLen = 64;
93+
uint8_t buffer[bufLen];
8494
} *context;
8595
};
8696

0 commit comments

Comments
 (0)