Skip to content

Commit 5edd3fc

Browse files
authored
Settable confidence thresholds (#25)
* Settable confidence thresholds * Remove debug statements * Pretty print --print-info * Ignore .DS_Store
1 parent d94c873 commit 5edd3fc

File tree

3 files changed

+124
-7
lines changed

3 files changed

+124
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ opencv/
1616
*.onnx
1717
convert-savedmodel/saved_model/
1818
debug.bmp
19+
.DS_Store

source/custom.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ int main(int argc, char **argv) {
8585
numpy::signal_from_buffer(&raw_features[0], raw_features.size(), &signal);
8686

8787
EI_IMPULSE_ERROR res = run_classifier(&signal, &result, false);
88+
if (res != EI_IMPULSE_OK) {
89+
printf("run_classifier failed (%d)\n", (int)res);
90+
return 1;
91+
}
8892
// print the predictions
8993
printf("Predictions (DSP: %d ms., Classification: %d ms., Anomaly: %d ms.): \n",
9094
result.timing.dsp, result.timing.classification, result.timing.anomaly);
@@ -119,7 +123,7 @@ int main(int argc, char **argv) {
119123
printf("\n");
120124
}
121125
#endif
122-
#if EI_CLASSIFIER_HAS_ANOMALY == 3 // visual AD
126+
#if EI_CLASSIFIER_HAS_ANOMALY == EI_ANOMALY_TYPE_VISUAL_GMM // visual AD
123127
printf("#Visual anomaly grid results:\n");
124128
for (uint32_t i = 0; i < result.visual_ad_count; i++) {
125129
ei_impulse_result_bounding_box_t bb = result.visual_ad_grid_cells[i];

source/eim.cpp

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
#include <cstring>
44
#include <iostream>
55
#include <sstream>
6+
// #include <thread>
67
#include <chrono>
78
#include <vector>
89
#include "edge-impulse-sdk/classifier/ei_run_classifier.h"
910
#include "json/json.hpp"
1011
#include "rapidjson/document.h"
12+
#include "rapidjson/prettywriter.h"
13+
#include "rapidjson/stringbuffer.h"
1114
#include <sys/socket.h>
1215
#include <sys/un.h>
1316
#include <unistd.h>
@@ -165,11 +168,12 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
165168

166169
auto id = id_v.GetInt();
167170

168-
rapidjson::Value& version = msg["hello"];
171+
rapidjson::Value& hello = msg["hello"];
169172
rapidjson::Value& classify_data = msg["classify"];
170173
rapidjson::Value& classify_data_continuous = msg["classify_continuous"];
174+
rapidjson::Value& set_threshold = msg["set_threshold"];
171175

172-
if (version.IsInt()) {
176+
if (hello.IsInt()) {
173177
if (state.initialized) {
174178
nlohmann::json err = {
175179
{"id", id},
@@ -180,11 +184,11 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
180184
return;
181185
}
182186

183-
if (version.GetInt() != 1) {
187+
if (hello.GetInt() != 1) {
184188
nlohmann::json err = {
185189
{"id", id},
186190
{"success", false},
187-
{"error", "Invalid 'version', only 1 supported"},
191+
{"error", "Invalid value for 'hello', only 1 supported"},
188192
};
189193
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
190194
return;
@@ -219,6 +223,30 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
219223
const char *model_type = "classification";
220224
#endif // EI_CLASSIFIER_OBJECT_DETECTION
221225

226+
// keep track of configurable thresholds
227+
nlohmann::json thresholds = nlohmann::json::array();
228+
for (size_t ix = 0; ix < ei_learning_blocks_size; ix++) {
229+
const ei_learning_block_t learn_block = ei_learning_blocks[ix];
230+
if (learn_block.infer_fn == run_gmm_anomaly) {
231+
ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t*)learn_block.config;
232+
thresholds.push_back({
233+
{ "id", learn_block.blockId },
234+
{ "type", "anomaly_gmm" },
235+
{ "min_anomaly_score", config->anomaly_threshold }
236+
});
237+
}
238+
else if (learn_block.infer_fn == run_nn_inference) {
239+
ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t*)learn_block.config;
240+
if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
241+
thresholds.push_back({
242+
{ "id", learn_block.blockId },
243+
{ "type", "object_detection" },
244+
{ "min_score", config->threshold }
245+
});
246+
}
247+
}
248+
}
249+
222250
nlohmann::json resp = {
223251
{"id", id},
224252
{"success", true},
@@ -246,13 +274,14 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
246274
{"slice_size", EI_CLASSIFIER_SLICE_SIZE},
247275
{"use_continuous_mode", EI_CLASSIFIER_SENSOR == EI_CLASSIFIER_SENSOR_MICROPHONE},
248276
{"inferencing_engine", EI_CLASSIFIER_INFERENCING_ENGINE},
277+
{"thresholds", thresholds},
249278
}},
250279
};
251280

252281
snprintf(resp_buffer, resp_buffer_size, "%s\n", resp.dump().c_str());
253282

254283
state.initialized = true;
255-
state.version = version.GetInt();
284+
state.version = hello.GetInt();
256285
}
257286
else if (!state.initialized) {
258287
nlohmann::json err = {
@@ -353,6 +382,59 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
353382
json_send_classification_response(id, json_parsing_ms, stdin_ms,
354383
res, &result, resp_buffer, resp_buffer_size);
355384
}
385+
else if (set_threshold.IsObject()) {
386+
if (!set_threshold.HasMember("id") || !set_threshold["id"].IsInt()) {
387+
nlohmann::json err = {
388+
{"id", id},
389+
{"success", false},
390+
{"error", "set_threshold should have a numeric field 'id'"},
391+
};
392+
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
393+
return;
394+
}
395+
396+
bool found_block = false;
397+
int block_id = set_threshold["id"].GetInt();
398+
for (size_t ix = 0; ix < ei_learning_blocks_size; ix++) {
399+
const ei_learning_block_t learn_block = ei_learning_blocks[ix];
400+
if (learn_block.blockId != block_id) continue;
401+
402+
found_block = true;
403+
404+
if (learn_block.infer_fn == run_gmm_anomaly) {
405+
ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t*)learn_block.config;
406+
407+
if (set_threshold.HasMember("min_anomaly_score") && set_threshold["min_anomaly_score"].IsNumber()) {
408+
config->anomaly_threshold = set_threshold["min_anomaly_score"].GetFloat();
409+
}
410+
}
411+
else if (learn_block.infer_fn == run_nn_inference) {
412+
ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t*)learn_block.config;
413+
if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
414+
if (set_threshold.HasMember("min_score") && set_threshold["min_score"].IsNumber()) {
415+
config->threshold = set_threshold["min_score"].GetFloat();
416+
}
417+
}
418+
}
419+
}
420+
421+
if (!found_block) {
422+
nlohmann::json err = {
423+
{"id", id},
424+
{"success", false},
425+
{"error", "set_threshold: cannot find learn block with this id"},
426+
};
427+
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
428+
return;
429+
}
430+
431+
nlohmann::json resp = {
432+
{"id", id},
433+
{"success", true},
434+
};
435+
snprintf(resp_buffer, resp_buffer_size, "%s\n", resp.dump().c_str());
436+
return;
437+
}
356438
else {
357439
nlohmann::json err = {
358440
{"id", id},
@@ -364,6 +446,32 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
364446
}
365447
}
366448

449+
int print_metadata_main() {
450+
char output_buffer[100 * 1024] = { 0 };
451+
452+
// Construct a hello msg into output_buffer
453+
{
454+
rapidjson::Document msg;
455+
msg.SetObject();
456+
rapidjson::Document::AllocatorType& allocator = msg.GetAllocator();
457+
msg.AddMember("id", 1, allocator);
458+
msg.AddMember("hello", 1, allocator);
459+
json_message_handler(msg, output_buffer, 100 * 1024, 0, 0);
460+
}
461+
462+
// pretty print (by first parsing, then re-printing)
463+
{
464+
rapidjson::Document document;
465+
document.Parse(output_buffer);
466+
rapidjson::StringBuffer buffer;
467+
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
468+
document.Accept(writer);
469+
printf("%s\n", buffer.GetString());
470+
}
471+
472+
return 0;
473+
}
474+
367475
int stdin_main() {
368476
static char *stdin_buffer = (char *)malloc(STDIN_BUFFER_SIZE);
369477
static char *response_buffer = (char *)calloc(STDIN_BUFFER_SIZE, 1);
@@ -559,12 +667,16 @@ int main(int argc, char **argv) {
559667
setvbuf(stdout, NULL, _IONBF, 0);
560668

561669
if (argc < 2) {
562-
printf("Requires one parameter (either: 'stdin' or the name of a socket)\n");
670+
printf("Requires one parameter (either: '--print-info', 'stdin' or the name of a socket)\n");
563671
return 1;
564672
}
565673

566674
state.initialized = false;
567675

676+
if (strcmp(argv[1], "--print-info") == 0) {
677+
printf("Edge Impulse Linux impulse runner - printing model metadata\n");
678+
return print_metadata_main();
679+
}
568680
if (strcmp(argv[1], "stdin") == 0) {
569681
printf("Edge Impulse Linux impulse runner - listening for JSON messages on stdin\n");
570682
return stdin_main();

0 commit comments

Comments
 (0)