Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit bc2dc4b

Browse files
committedAug 17, 2024·
feat: Middleware support for WebServer
Moving default middlewares (cors, authc, logging) to examples
1 parent def319a commit bc2dc4b

File tree

9 files changed

+729
-85
lines changed

9 files changed

+729
-85
lines changed
 
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include <WiFi.h>
2+
#include <WebServer.h>
3+
#include <Middlewares.h>
4+
5+
// Your AP WiFi Credentials
6+
// ( This is the AP your ESP will broadcast )
7+
const char *ap_ssid = "ESP32_Demo";
8+
const char *ap_password = "";
9+
10+
WebServer server(80);
11+
12+
LoggingMiddleware logger(Serial);
13+
CorsMiddleware cors;
14+
AuthenticationMiddleware auth;
15+
16+
void setup(void) {
17+
Serial.begin(115200);
18+
WiFi.softAP(ap_ssid, ap_password);
19+
20+
Serial.print("IP address: ");
21+
Serial.println(WiFi.AP.localIP());
22+
23+
cors.origin("http://192.168.4.1");
24+
cors.methods("POST, GET, OPTIONS, DELETE");
25+
cors.headers("X-Custom-Header");
26+
cors.allowCredentials(false);
27+
cors.maxAge(600);
28+
29+
auth.authenticate("admin", "admin");
30+
31+
server
32+
.on(
33+
"/",
34+
[]() {
35+
server.send(200, "text/plain", "Home");
36+
}
37+
)
38+
.addMiddleware(&logger)
39+
.addMiddleware(&cors)
40+
.addMiddleware(&auth);
41+
42+
server.onNotFound([]() {
43+
server.send(404, "text/plain", "Page not found");
44+
});
45+
46+
server.collectAllHeaders();
47+
server.begin();
48+
Serial.println("HTTP server started");
49+
}
50+
51+
void loop(void) {
52+
server.handleClient();
53+
delay(2); //allow the cpu to switch to other tasks
54+
}
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
#ifndef MIDDLEWARES_H
2+
#define MIDDLEWARES_H
3+
4+
#include <WebServer.h>
5+
#include <Stream.h>
6+
7+
#include <assert.h>
8+
9+
// curl-like logging middleware
10+
class LoggingMiddleware : public Middleware {
11+
private:
12+
Stream *_out;
13+
14+
public:
15+
explicit LoggingMiddleware(Stream &out) : _out(&out) {}
16+
17+
bool run(WebServer &server, Middleware::Callback next) override {
18+
_out->print(F("* Connection from "));
19+
_out->print(server.client().remoteIP().toString());
20+
_out->print(F(":"));
21+
_out->println(server.client().remotePort());
22+
23+
_out->print(F("< "));
24+
HTTPMethod method = server.method();
25+
if (method == HTTP_ANY) {
26+
_out->print(F("HTTP_ANY"));
27+
} else {
28+
_out->print(http_method_str((http_method)method));
29+
}
30+
_out->print(F(" "));
31+
_out->print(server.uri());
32+
_out->print(F(" "));
33+
_out->println(server.version());
34+
35+
int n = server.headers();
36+
for (int i = 0; i < n; i++) {
37+
String v = server.header(i);
38+
if (!v.isEmpty()) {
39+
// because these 2 are always there, eventually empty: "Authorization", "If-None-Match"
40+
_out->print(F("< "));
41+
_out->print(server.headerName(i));
42+
_out->print(F(": "));
43+
_out->println(server.header(i));
44+
}
45+
}
46+
47+
_out->println(F("<"));
48+
49+
bool ret = next();
50+
51+
if (ret) {
52+
_out->println(F("* Processed!"));
53+
54+
_out->print(F("> "));
55+
_out->print(F("HTTP/1."));
56+
_out->print(server.version());
57+
_out->print(F(" "));
58+
_out->print(server.responseCode());
59+
_out->print(F(" "));
60+
_out->println(WebServer::responseCodeToString(server.responseCode()));
61+
62+
n = server.responseHeaders();
63+
for (int i = 0; i < n; i++) {
64+
_out->print(F("> "));
65+
_out->print(server.responseHeaderName(i));
66+
_out->print(F(": "));
67+
_out->println(server.responseHeader(i));
68+
}
69+
70+
_out->println(F(">"));
71+
72+
} else {
73+
_out->println(F("* Not processed!"));
74+
}
75+
76+
return ret;
77+
}
78+
};
79+
80+
class AuthenticationMiddleware : public Middleware {
81+
private:
82+
// authenticate state
83+
// 0: not authenticated
84+
// 1: callback
85+
// 2: username/password
86+
// 3: sha1
87+
int _auth = 0;
88+
WebServer::THandlerFunctionAuthCheck _fn;
89+
String _username;
90+
String _password;
91+
String _sha1;
92+
93+
// authenticate request
94+
HTTPAuthMethod _mode = BASIC_AUTH;
95+
String _realm;
96+
String _authFailMsg;
97+
98+
public:
99+
AuthenticationMiddleware &authenticate(WebServer::THandlerFunctionAuthCheck fn) {
100+
assert(fn);
101+
_fn = fn;
102+
_auth = 1;
103+
return *this;
104+
}
105+
106+
AuthenticationMiddleware &authenticate(const char *username, const char *password) {
107+
if (strlen(username) == 0 || strlen(password) == 0) {
108+
_auth = 0;
109+
return *this;
110+
} else {
111+
_username = username;
112+
_password = password;
113+
_auth = 2;
114+
return *this;
115+
}
116+
}
117+
118+
AuthenticationMiddleware &authenticateBasicSHA1(const char *username, const char *sha1AsBase64orHex) {
119+
if (strlen(username) == 0 || strlen(sha1AsBase64orHex) == 0) {
120+
_auth = 0;
121+
return *this;
122+
}
123+
_username = username;
124+
_sha1 = sha1AsBase64orHex;
125+
_auth = 3;
126+
return *this;
127+
}
128+
129+
bool run(WebServer &server, Middleware::Callback next) override {
130+
switch (_auth) {
131+
case 1:
132+
if (server.authenticate(_fn)) {
133+
return next();
134+
} else {
135+
server.requestAuthentication(_mode, _realm.c_str(), _authFailMsg);
136+
return true;
137+
}
138+
139+
case 2:
140+
if (server.authenticate(_username.c_str(), _password.c_str())) {
141+
return next();
142+
} else {
143+
server.requestAuthentication(_mode, _realm.c_str(), _authFailMsg);
144+
return true;
145+
}
146+
147+
case 3:
148+
if (server.authenticate(_username.c_str(), _sha1.c_str())) {
149+
return next();
150+
} else {
151+
server.requestAuthentication(_mode, _realm.c_str(), _authFailMsg);
152+
return true;
153+
}
154+
155+
default: return next();
156+
}
157+
}
158+
};
159+
160+
class CorsMiddleware : public Middleware {
161+
private:
162+
String _origin = F("*");
163+
String _methods = F("*");
164+
String _headers = F("*");
165+
bool _credentials = true;
166+
uint32_t _maxAge = 86400;
167+
168+
public:
169+
CorsMiddleware &origin(const char *origin) {
170+
_origin = origin;
171+
return *this;
172+
}
173+
174+
CorsMiddleware &methods(const char *methods) {
175+
_methods = methods;
176+
return *this;
177+
}
178+
179+
CorsMiddleware &headers(const char *headers) {
180+
_headers = headers;
181+
return *this;
182+
}
183+
184+
CorsMiddleware &allowCredentials(bool credentials) {
185+
_credentials = credentials;
186+
return *this;
187+
}
188+
189+
CorsMiddleware &maxAge(uint32_t seconds) {
190+
_maxAge = seconds;
191+
return *this;
192+
}
193+
194+
bool run(WebServer &server, Middleware::Callback next) override {
195+
if (server.method() == HTTP_OPTIONS) {
196+
server.sendHeader(F("Access-Control-Allow-Origin"), _origin.c_str());
197+
server.sendHeader(F("Access-Control-Allow-Methods"), _methods.c_str());
198+
server.sendHeader(F("Access-Control-Allow-Headers"), _headers.c_str());
199+
server.sendHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false"));
200+
server.sendHeader(F("Access-Control-Max-Age"), String(_maxAge).c_str());
201+
server.send(200);
202+
return true;
203+
}
204+
return next();
205+
}
206+
};
207+
208+
#endif
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
This example shows how to load all request headers and use middleware.
2+
3+
### CORS Middleware
4+
5+
```bash
6+
❯ curl -i -X OPTIONS http://192.168.4.1
7+
HTTP/1.1 200 OK
8+
Content-Type: text/html
9+
Access-Control-Allow-Origin: http://192.168.4.1
10+
Access-Control-Allow-Methods: POST, GET, OPTIONS, DELETE
11+
Access-Control-Allow-Headers: X-Custom-Header
12+
Access-Control-Allow-Credentials: false
13+
Access-Control-Max-Age: 600
14+
Content-Length: 0
15+
Connection: close
16+
```
17+
18+
Output of logger middleware:
19+
20+
```
21+
* Connection from 192.168.4.2:57597
22+
< OPTIONS / HTTP/1.1
23+
< Host: 192.168.4.1
24+
< User-Agent: curl/8.9.1
25+
< Accept: */*
26+
<
27+
* Processed!
28+
> HTTP/1.HTTP/1.1 200 OK
29+
> Content-Type: text/html
30+
> Access-Control-Allow-Origin: http://192.168.4.1
31+
> Access-Control-Allow-Methods: POST, GET, OPTIONS, DELETE
32+
> Access-Control-Allow-Headers: X-Custom-Header
33+
> Access-Control-Allow-Credentials: false
34+
> Access-Control-Max-Age: 600
35+
> Content-Length: 0
36+
> Connection: close
37+
>
38+
```
39+
40+
### Authentication Middleware
41+
42+
```bash
43+
❯ curl -i -X GET http://192.168.4.1
44+
HTTP/1.1 401 Unauthorized
45+
Content-Type: text/html
46+
WWW-Authenticate: Basic realm=""
47+
Content-Length: 0
48+
Connection: close
49+
```
50+
51+
Output of logger middleware:
52+
53+
```
54+
* Connection from 192.168.4.2:57705
55+
< GET / HTTP/1.1
56+
< Host: 192.168.4.1
57+
< User-Agent: curl/8.9.1
58+
< Accept: */*
59+
<
60+
* Processed!
61+
> HTTP/1.HTTP/1.1 401 Unauthorized
62+
> Content-Type: text/html
63+
> WWW-Authenticate: Basic realm=""
64+
> Content-Length: 0
65+
> Connection: close
66+
>
67+
```
68+
69+
Sending auth...
70+
71+
```bash
72+
Note: Unnecessary use of -X or --request, GET is already inferred.
73+
* Trying 192.168.4.1:80...
74+
* Connected to 192.168.4.1 (192.168.4.1) port 80
75+
* Server auth using Basic with user 'admin'
76+
> GET / HTTP/1.1
77+
> Host: 192.168.4.1
78+
> Authorization: Basic YWRtaW46YWRtaW4=
79+
> User-Agent: curl/8.9.1
80+
> Accept: */*
81+
>
82+
* Request completely sent off
83+
< HTTP/1.1 200 OK
84+
< Content-Type: text/plain
85+
< Content-Length: 4
86+
< Connection: close
87+
<
88+
* shutting down connection #0
89+
Home
90+
91+
```
92+
93+
Output of logger middleware:
94+
95+
```
96+
* Connection from 192.168.4.2:62099
97+
< GET / HTTP/1.1
98+
< Authorization: Basic YWRtaW46YWRtaW4=
99+
< Host: 192.168.4.1
100+
< User-Agent: curl/8.9.1
101+
< Accept: */*
102+
<
103+
* Processed!
104+
> HTTP/1.HTTP/1.1 200 OK
105+
> Content-Type: text/plain
106+
> Content-Length: 4
107+
> Connection: close
108+
>
109+
```
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"targets": {
3+
"esp32h2": false
4+
}
5+
}

‎libraries/WebServer/src/Parsing.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,14 @@ bool WebServer::_parseRequest(NetworkClient &client) {
7878
String req = client.readStringUntil('\r');
7979
client.readStringUntil('\n');
8080
//reset header value
81-
for (int i = 0; i < _headerKeysCount; ++i) {
82-
_currentHeaders[i].value = String();
81+
if (_collectAllHeaders) {
82+
// clear previous headers
83+
collectAllHeaders();
84+
} else {
85+
// clear previous headers
86+
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
87+
header->value = String();
88+
}
8389
}
8490

8591
// First line of HTTP request looks like "GET /path HTTP/1.1"
@@ -154,9 +160,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
154160
headerValue.trim();
155161
_collectHeader(headerName.c_str(), headerValue.c_str());
156162

157-
log_v("headerName: %s", headerName.c_str());
158-
log_v("headerValue: %s", headerValue.c_str());
159-
160163
if (headerName.equalsIgnoreCase(FPSTR(Content_Type))) {
161164
using namespace mime;
162165
if (headerValue.startsWith(FPSTR(mimeTable[txt].mimeType))) {
@@ -253,9 +256,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
253256
headerValue = req.substring(headerDiv + 2);
254257
_collectHeader(headerName.c_str(), headerValue.c_str());
255258

256-
log_v("headerName: %s", headerName.c_str());
257-
log_v("headerValue: %s", headerValue.c_str());
258-
259259
if (headerName.equalsIgnoreCase("Host")) {
260260
_hostHeader = headerValue;
261261
}
@@ -271,12 +271,29 @@ bool WebServer::_parseRequest(NetworkClient &client) {
271271
}
272272

273273
bool WebServer::_collectHeader(const char *headerName, const char *headerValue) {
274-
for (int i = 0; i < _headerKeysCount; i++) {
275-
if (_currentHeaders[i].key.equalsIgnoreCase(headerName)) {
276-
_currentHeaders[i].value = headerValue;
274+
RequestArgument *last = nullptr;
275+
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
276+
if (header->next == nullptr) {
277+
last = header;
278+
}
279+
if (header->key.equalsIgnoreCase(headerName)) {
280+
header->value = headerValue;
281+
log_v("header collected: %s: %s", headerName, headerValue);
277282
return true;
278283
}
279284
}
285+
assert(last);
286+
if (_collectAllHeaders) {
287+
last->next = new RequestArgument();
288+
last->next->key = headerName;
289+
last->next->value = headerValue;
290+
_headerKeysCount++;
291+
log_v("header collected: %s: %s", headerName, headerValue);
292+
return true;
293+
}
294+
295+
log_v("header skipped: %s: %s", headerName, headerValue);
296+
280297
return false;
281298
}
282299

‎libraries/WebServer/src/WebServer.cpp

Lines changed: 121 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,27 @@ static const char WWW_Authenticate[] = "WWW-Authenticate";
4141
static const char Content_Length[] = "Content-Length";
4242
static const char ETAG_HEADER[] = "If-None-Match";
4343

44-
WebServer::WebServer(IPAddress addr, int port)
45-
: _corsEnabled(false), _server(addr, port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
46-
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
47-
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
44+
WebServer::WebServer(IPAddress addr, int port) : _server(addr, port) {
4845
log_v("WebServer::Webserver(addr=%s, port=%d)", addr.toString().c_str(), port);
4946
}
5047

51-
WebServer::WebServer(int port)
52-
: _corsEnabled(false), _server(port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
53-
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
54-
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
48+
WebServer::WebServer(int port) : _server(port) {
5549
log_v("WebServer::Webserver(port=%d)", port);
5650
}
5751

5852
WebServer::~WebServer() {
5953
_server.close();
60-
if (_currentHeaders) {
61-
delete[] _currentHeaders;
62-
}
54+
55+
_clearRequestHeaders();
56+
_clearResponseHeaders();
57+
6358
RequestHandler *handler = _firstHandler;
6459
while (handler) {
6560
RequestHandler *next = handler->next();
6661
delete handler;
6762
handler = next;
6863
}
64+
_firstHandler = nullptr;
6965
}
7066

7167
void WebServer::begin() {
@@ -435,6 +431,8 @@ void WebServer::handleClient() {
435431
_currentClient.setTimeout(HTTP_MAX_SEND_WAIT); /* / 1000 removed, WifiClient setTimeout changed to ms */
436432
if (_parseRequest(_currentClient)) {
437433
_contentLength = CONTENT_LENGTH_NOT_SET;
434+
_responseCode = 0;
435+
_clearResponseHeaders();
438436
_handleRequest();
439437

440438
if (_currentClient.isSSE()) {
@@ -494,16 +492,22 @@ void WebServer::stop() {
494492
}
495493

496494
void WebServer::sendHeader(const String &name, const String &value, bool first) {
497-
String headerLine = name;
498-
headerLine += F(": ");
499-
headerLine += value;
500-
headerLine += "\r\n";
495+
RequestArgument *header = new RequestArgument();
496+
header->key = name;
497+
header->value = value;
501498

502-
if (first) {
503-
_responseHeaders = headerLine + _responseHeaders;
499+
if (!_responseHeaders || first) {
500+
header->next = _responseHeaders;
501+
_responseHeaders = header;
504502
} else {
505-
_responseHeaders += headerLine;
503+
RequestArgument *last = _responseHeaders;
504+
while (last->next) {
505+
last = last->next;
506+
}
507+
last->next = header;
506508
}
509+
510+
_responseHeaderCount++;
507511
}
508512

509513
void WebServer::setContentLength(const size_t contentLength) {
@@ -528,10 +532,13 @@ void WebServer::enableETag(bool enable, ETagFunction fn) {
528532
}
529533

530534
void WebServer::_prepareHeader(String &response, int code, const char *content_type, size_t contentLength) {
531-
response = String(F("HTTP/1.")) + String(_currentVersion) + ' ';
535+
_responseCode = code;
536+
537+
response = version();
538+
response += ' ';
532539
response += String(code);
533540
response += ' ';
534-
response += _responseCodeToString(code);
541+
response += responseCodeToString(code);
535542
response += "\r\n";
536543

537544
using namespace mime;
@@ -557,19 +564,21 @@ void WebServer::_prepareHeader(String &response, int code, const char *content_t
557564
}
558565
sendHeader(String(F("Connection")), String(F("close")));
559566

560-
response += _responseHeaders;
561-
response += "\r\n";
562-
_responseHeaders = "";
567+
for (RequestArgument *header = _responseHeaders; header; header = header->next) {
568+
response += header->key;
569+
response += F(": ");
570+
response += header->value;
571+
response += F("\r\n");
572+
}
573+
574+
response += F("\r\n");
563575
}
564576

565577
void WebServer::send(int code, const char *content_type, const String &content) {
566578
String header;
567579
// Can we assume the following?
568580
//if(code == 200 && content.length() == 0 && _contentLength == CONTENT_LENGTH_NOT_SET)
569581
// _contentLength = CONTENT_LENGTH_UNKNOWN;
570-
if (content.length() == 0) {
571-
log_w("content length is zero");
572-
}
573582
_prepareHeader(header, code, content_type, content.length());
574583
_currentClientWrite(header.c_str(), header.length());
575584
if (content.length()) {
@@ -727,52 +736,93 @@ bool WebServer::hasArg(String name) {
727736
}
728737

729738
String WebServer::header(String name) {
730-
for (int i = 0; i < _headerKeysCount; ++i) {
731-
if (_currentHeaders[i].key.equalsIgnoreCase(name)) {
732-
return _currentHeaders[i].value;
739+
for (RequestArgument *current = _currentHeaders; current; current = current->next) {
740+
if (current->key.equalsIgnoreCase(name)) {
741+
return current->value;
733742
}
734743
}
735744
return "";
736745
}
737746

738747
void WebServer::collectHeaders(const char *headerKeys[], const size_t headerKeysCount) {
739-
_headerKeysCount = headerKeysCount + 2;
740-
if (_currentHeaders) {
741-
delete[] _currentHeaders;
742-
}
743-
_currentHeaders = new RequestArgument[_headerKeysCount];
744-
_currentHeaders[0].key = FPSTR(AUTHORIZATION_HEADER);
745-
_currentHeaders[1].key = FPSTR(ETAG_HEADER);
748+
collectAllHeaders();
749+
_collectAllHeaders = false;
750+
751+
_headerKeysCount += headerKeysCount;
752+
753+
RequestArgument *last = _currentHeaders->next;
754+
746755
for (int i = 2; i < _headerKeysCount; i++) {
747-
_currentHeaders[i].key = headerKeys[i - 2];
756+
last->next = new RequestArgument();
757+
last->next->key = headerKeys[i - 2];
758+
last = last->next;
748759
}
749760
}
750761

762+
void WebServer::collectAllHeaders() {
763+
_clearRequestHeaders();
764+
765+
_currentHeaders = new RequestArgument();
766+
_currentHeaders->key = FPSTR(AUTHORIZATION_HEADER);
767+
768+
_currentHeaders->next = new RequestArgument();
769+
_currentHeaders->next->key = FPSTR(ETAG_HEADER);
770+
771+
_headerKeysCount = 2;
772+
_collectAllHeaders = true;
773+
}
774+
751775
String WebServer::header(int i) {
752-
if (i < _headerKeysCount) {
753-
return _currentHeaders[i].value;
776+
RequestArgument *current = _currentHeaders;
777+
while (current && i--) {
778+
current = current->next;
754779
}
755-
return "";
780+
return current ? current->value : emptyString;
756781
}
757782

758783
String WebServer::headerName(int i) {
759-
if (i < _headerKeysCount) {
760-
return _currentHeaders[i].key;
784+
RequestArgument *current = _currentHeaders;
785+
while (current && i--) {
786+
current = current->next;
761787
}
762-
return "";
788+
return current ? current->key : emptyString;
763789
}
764790

765791
int WebServer::headers() {
766792
return _headerKeysCount;
767793
}
768794

769795
bool WebServer::hasHeader(String name) {
770-
for (int i = 0; i < _headerKeysCount; ++i) {
771-
if ((_currentHeaders[i].key.equalsIgnoreCase(name)) && (_currentHeaders[i].value.length() > 0)) {
772-
return true;
796+
return header(name).length() > 0;
797+
}
798+
799+
const String &WebServer::responseHeader(String name) {
800+
for (RequestArgument *current = _responseHeaders; current; current = current->next) {
801+
if (current->key.equalsIgnoreCase(name)) {
802+
return current->value;
773803
}
774804
}
775-
return false;
805+
return emptyString;
806+
}
807+
808+
const String &WebServer::responseHeader(int i) {
809+
RequestArgument *current = _responseHeaders;
810+
while (current && i--) {
811+
current = current->next;
812+
}
813+
return current ? current->value : emptyString;
814+
}
815+
816+
const String &WebServer::responseHeaderName(int i) {
817+
RequestArgument *current = _responseHeaders;
818+
while (current && i--) {
819+
current = current->next;
820+
}
821+
return current ? current->key : emptyString;
822+
}
823+
824+
bool WebServer::hasResponseHeader(String name) {
825+
return header(name).length() > 0;
776826
}
777827

778828
String WebServer::hostHeader() {
@@ -792,7 +842,7 @@ void WebServer::_handleRequest() {
792842
if (!_currentHandler) {
793843
log_e("request handler not found");
794844
} else {
795-
handled = _currentHandler->handle(*this, _currentMethod, _currentUri);
845+
handled = _currentHandler->process(*this, _currentMethod, _currentUri);
796846
if (!handled) {
797847
log_e("request handler failed to handle request");
798848
}
@@ -818,7 +868,29 @@ void WebServer::_finalizeResponse() {
818868
}
819869
}
820870

821-
String WebServer::_responseCodeToString(int code) {
871+
void WebServer::_clearResponseHeaders() {
872+
_responseHeaderCount = 0;
873+
RequestArgument *current = _responseHeaders;
874+
while (current) {
875+
RequestArgument *next = current->next;
876+
delete current;
877+
current = next;
878+
}
879+
_responseHeaders = nullptr;
880+
}
881+
882+
void WebServer::_clearRequestHeaders() {
883+
_headerKeysCount = 0;
884+
RequestArgument *current = _currentHeaders;
885+
while (current) {
886+
RequestArgument *next = current->next;
887+
delete current;
888+
current = next;
889+
}
890+
_currentHeaders = nullptr;
891+
}
892+
893+
String WebServer::responseCodeToString(int code) {
822894
switch (code) {
823895
case 100: return F("Continue");
824896
case 101: return F("Switching Protocols");

‎libraries/WebServer/src/WebServer.h

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ typedef struct {
9292
void *data; // additional data
9393
} HTTPRaw;
9494

95+
#include "detail/Middleware.h"
9596
#include "detail/RequestHandler.h"
9697

9798
namespace fs {
@@ -181,11 +182,31 @@ class WebServer {
181182
int args(); // get arguments count
182183
bool hasArg(String name); // check if argument exists
183184
void collectHeaders(const char *headerKeys[], const size_t headerKeysCount); // set the request headers to collect
185+
void collectAllHeaders(); // collect all request headers
184186
String header(String name); // get request header value by name
185187
String header(int i); // get request header value by number
186188
String headerName(int i); // get request header name by number
187189
int headers(); // get header count
188190
bool hasHeader(String name); // check if header exists
191+
const String version() const {
192+
String v;
193+
v.reserve(8);
194+
v.concat(F("HTTP/1."));
195+
v.concat(_currentVersion);
196+
return v;
197+
}
198+
199+
int responseCode() {
200+
return _responseCode;
201+
}
202+
int responseHeaders() {
203+
return _responseHeaderCount;
204+
}
205+
const String &responseHeader(String name);
206+
const String &responseHeader(int i);
207+
const String &responseHeaderName(int i);
208+
209+
bool hasResponseHeader(String name); // check if response header exists
189210

190211
int clientContentLength() {
191212
return _clientContentLength;
@@ -228,6 +249,8 @@ class WebServer {
228249
bool _eTagEnabled = false;
229250
ETagFunction _eTagFunction = nullptr;
230251

252+
static String responseCodeToString(int code);
253+
231254
protected:
232255
virtual size_t _currentClientWrite(const char *b, size_t l) {
233256
return _currentClient.write(b, l);
@@ -241,7 +264,6 @@ class WebServer {
241264
void _finalizeResponse();
242265
bool _parseRequest(NetworkClient &client);
243266
void _parseArguments(String data);
244-
static String _responseCodeToString(int code);
245267
bool _parseForm(NetworkClient &client, String boundary, uint32_t len);
246268
bool _parseFormUploadAborted();
247269
void _uploadWriteByte(uint8_t b);
@@ -255,44 +277,51 @@ class WebServer {
255277
// for extracting Auth parameters
256278
String _extractParam(String &authReq, const String &param, const char delimit = '"');
257279

280+
void _clearResponseHeaders();
281+
void _clearRequestHeaders();
282+
258283
struct RequestArgument {
259284
String key;
260285
String value;
286+
RequestArgument *next;
261287
};
262288

263-
boolean _corsEnabled;
289+
boolean _corsEnabled = false;
264290
NetworkServer _server;
265291

266292
NetworkClient _currentClient;
267-
HTTPMethod _currentMethod;
293+
HTTPMethod _currentMethod = HTTP_ANY;
268294
String _currentUri;
269-
uint8_t _currentVersion;
270-
HTTPClientStatus _currentStatus;
271-
unsigned long _statusChange;
272-
boolean _nullDelay;
273-
274-
RequestHandler *_currentHandler;
275-
RequestHandler *_firstHandler;
276-
RequestHandler *_lastHandler;
277-
THandlerFunction _notFoundHandler;
278-
THandlerFunction _fileUploadHandler;
279-
280-
int _currentArgCount;
281-
RequestArgument *_currentArgs;
282-
int _postArgsLen;
283-
RequestArgument *_postArgs;
295+
uint8_t _currentVersion = 0;
296+
HTTPClientStatus _currentStatus = HC_NONE;
297+
unsigned long _statusChange = 0;
298+
boolean _nullDelay = true;
299+
300+
RequestHandler *_currentHandler = nullptr;
301+
RequestHandler *_firstHandler = nullptr;
302+
RequestHandler *_lastHandler = nullptr;
303+
THandlerFunction _notFoundHandler = nullptr;
304+
THandlerFunction _fileUploadHandler = nullptr;
305+
306+
int _currentArgCount = 0;
307+
RequestArgument *_currentArgs = nullptr;
308+
int _postArgsLen = 0;
309+
RequestArgument *_postArgs = nullptr;
284310

285311
std::unique_ptr<HTTPUpload> _currentUpload;
286312
std::unique_ptr<HTTPRaw> _currentRaw;
287313

288-
int _headerKeysCount;
289-
RequestArgument *_currentHeaders;
290-
size_t _contentLength;
291-
int _clientContentLength; // "Content-Length" from header of incoming POST or GET request
292-
String _responseHeaders;
314+
bool _collectAllHeaders = false;
315+
int _headerKeysCount = 0;
316+
RequestArgument *_currentHeaders = nullptr;
317+
size_t _contentLength = 0;
318+
int _clientContentLength = 0; // "Content-Length" from header of incoming POST or GET request
319+
RequestArgument *_responseHeaders = nullptr;
320+
int _responseHeaderCount = 0;
321+
int _responseCode = 0;
293322

294323
String _hostHeader;
295-
bool _chunked;
324+
bool _chunked = false;
296325

297326
String _snonce; // Store noance and opaque for future comparison
298327
String _sopaque;
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#ifndef MIDDLEWARE_H
2+
#define MIDDLEWARE_H
3+
4+
#include <assert.h>
5+
6+
class MiddlewareChain;
7+
class WebServer;
8+
9+
class Middleware {
10+
public:
11+
typedef std::function<bool(void)> Callback;
12+
typedef std::function<bool(WebServer &server, Callback next)> Function;
13+
virtual ~Middleware() {}
14+
virtual bool run(WebServer &server, Callback next) {
15+
return next();
16+
};
17+
18+
private:
19+
friend MiddlewareChain;
20+
Middleware *_next = nullptr;
21+
bool _managed = false;
22+
};
23+
24+
class MiddlewareHandle : public Middleware {
25+
private:
26+
Middleware::Function _fn;
27+
28+
public:
29+
MiddlewareHandle(Middleware::Function fn) : _fn(fn) {
30+
assert(fn);
31+
}
32+
33+
bool run(WebServer &server, Middleware::Callback next) override {
34+
return _fn(server, next);
35+
}
36+
};
37+
38+
class MiddlewareChain {
39+
private:
40+
Middleware *_root = nullptr;
41+
Middleware *_current = nullptr;
42+
43+
public:
44+
~MiddlewareChain() {
45+
Middleware *current = _root;
46+
while (current) {
47+
Middleware *next = current->_next;
48+
if (current->_managed) {
49+
delete current;
50+
}
51+
current = next;
52+
}
53+
_root = nullptr;
54+
}
55+
56+
Middleware *add(Middleware *middleware) {
57+
if (!_root) {
58+
_root = middleware;
59+
return middleware;
60+
}
61+
Middleware *current = _root;
62+
while (current->_next) {
63+
current = current->_next;
64+
}
65+
current->_next = middleware;
66+
return middleware;
67+
}
68+
69+
Middleware *add(Middleware::Function fn) {
70+
MiddlewareHandle *middleware = new MiddlewareHandle(fn);
71+
middleware->_managed = true;
72+
return add(middleware);
73+
}
74+
75+
bool remove(Middleware *middleware) {
76+
if (!_root) {
77+
return false;
78+
}
79+
if (_root == middleware) {
80+
_root = _root->_next;
81+
if (middleware->_managed) {
82+
delete middleware;
83+
}
84+
return true;
85+
}
86+
Middleware *current = _root;
87+
while (current->_next) {
88+
if (current->_next == middleware) {
89+
current->_next = current->_next->_next;
90+
if (middleware->_managed) {
91+
delete middleware;
92+
}
93+
return true;
94+
}
95+
current = current->_next;
96+
}
97+
return false;
98+
}
99+
100+
bool run(WebServer &server, Middleware::Callback finalizer) {
101+
if (!_root) {
102+
return finalizer();
103+
}
104+
_current = _root;
105+
Middleware::Callback next;
106+
next = [this, &server, &next, finalizer]() {
107+
if (_current) {
108+
Middleware *that = _current;
109+
_current = _current->_next;
110+
return that->run(server, next);
111+
} else {
112+
return finalizer();
113+
}
114+
};
115+
return next();
116+
}
117+
};
118+
119+
#endif

‎libraries/WebServer/src/detail/RequestHandler.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
class RequestHandler {
88
public:
9-
virtual ~RequestHandler() {}
9+
virtual ~RequestHandler() {
10+
if (_chain) {
11+
delete _chain;
12+
}
13+
}
1014

1115
/*
1216
note: old handler API for backward compatibility
@@ -75,8 +79,35 @@ class RequestHandler {
7579
_next = r;
7680
}
7781

82+
RequestHandler &addMiddleware(Middleware *middleware) {
83+
if (!_chain) {
84+
_chain = new MiddlewareChain();
85+
}
86+
_chain->add(middleware);
87+
return *this;
88+
}
89+
90+
RequestHandler &addMiddleware(Middleware::Function fn) {
91+
if (!_chain) {
92+
_chain = new MiddlewareChain();
93+
}
94+
_chain->add(fn);
95+
return *this;
96+
}
97+
98+
bool process(WebServer &server, HTTPMethod requestMethod, String requestUri) {
99+
if (_chain) {
100+
return _chain->run(server, [this, &server, &requestMethod, &requestUri]() {
101+
return handle(server, requestMethod, requestUri);
102+
});
103+
} else {
104+
return handle(server, requestMethod, requestUri);
105+
}
106+
}
107+
78108
private:
79109
RequestHandler *_next = nullptr;
110+
MiddlewareChain *_chain = nullptr;
80111

81112
protected:
82113
std::vector<String> pathArgs;

0 commit comments

Comments
 (0)
Please sign in to comment.