3
3
#include < cstring>
4
4
#include < iostream>
5
5
#include < sstream>
6
+ // #include <thread>
6
7
#include < chrono>
7
8
#include < vector>
8
9
#include " edge-impulse-sdk/classifier/ei_run_classifier.h"
9
10
#include " json/json.hpp"
10
11
#include " rapidjson/document.h"
12
+ #include " rapidjson/prettywriter.h"
13
+ #include " rapidjson/stringbuffer.h"
11
14
#include < sys/socket.h>
12
15
#include < sys/un.h>
13
16
#include < unistd.h>
@@ -165,11 +168,12 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
165
168
166
169
auto id = id_v.GetInt ();
167
170
168
- rapidjson::Value& version = msg[" hello" ];
171
+ rapidjson::Value& hello = msg[" hello" ];
169
172
rapidjson::Value& classify_data = msg[" classify" ];
170
173
rapidjson::Value& classify_data_continuous = msg[" classify_continuous" ];
174
+ rapidjson::Value& set_threshold = msg[" set_threshold" ];
171
175
172
- if (version .IsInt ()) {
176
+ if (hello .IsInt ()) {
173
177
if (state.initialized ) {
174
178
nlohmann::json err = {
175
179
{" id" , id},
@@ -180,11 +184,11 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
180
184
return ;
181
185
}
182
186
183
- if (version .GetInt () != 1 ) {
187
+ if (hello .GetInt () != 1 ) {
184
188
nlohmann::json err = {
185
189
{" id" , id},
186
190
{" success" , false },
187
- {" error" , " Invalid 'version ', only 1 supported" },
191
+ {" error" , " Invalid value for 'hello ', only 1 supported" },
188
192
};
189
193
snprintf (resp_buffer, resp_buffer_size, " %s\n " , err.dump ().c_str ());
190
194
return ;
@@ -219,6 +223,30 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
219
223
const char *model_type = " classification" ;
220
224
#endif // EI_CLASSIFIER_OBJECT_DETECTION
221
225
226
+ // keep track of configurable thresholds
227
+ nlohmann::json thresholds = nlohmann::json::array ();
228
+ for (size_t ix = 0 ; ix < ei_learning_blocks_size; ix++) {
229
+ const ei_learning_block_t learn_block = ei_learning_blocks[ix];
230
+ if (learn_block.infer_fn == run_gmm_anomaly) {
231
+ ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t *)learn_block.config ;
232
+ thresholds.push_back ({
233
+ { " id" , learn_block.blockId },
234
+ { " type" , " anomaly_gmm" },
235
+ { " min_anomaly_score" , config->anomaly_threshold }
236
+ });
237
+ }
238
+ else if (learn_block.infer_fn == run_nn_inference) {
239
+ ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t *)learn_block.config ;
240
+ if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
241
+ thresholds.push_back ({
242
+ { " id" , learn_block.blockId },
243
+ { " type" , " object_detection" },
244
+ { " min_score" , config->threshold }
245
+ });
246
+ }
247
+ }
248
+ }
249
+
222
250
nlohmann::json resp = {
223
251
{" id" , id},
224
252
{" success" , true },
@@ -246,13 +274,14 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
246
274
{" slice_size" , EI_CLASSIFIER_SLICE_SIZE},
247
275
{" use_continuous_mode" , EI_CLASSIFIER_SENSOR == EI_CLASSIFIER_SENSOR_MICROPHONE},
248
276
{" inferencing_engine" , EI_CLASSIFIER_INFERENCING_ENGINE},
277
+ {" thresholds" , thresholds},
249
278
}},
250
279
};
251
280
252
281
snprintf (resp_buffer, resp_buffer_size, " %s\n " , resp.dump ().c_str ());
253
282
254
283
state.initialized = true ;
255
- state.version = version .GetInt ();
284
+ state.version = hello .GetInt ();
256
285
}
257
286
else if (!state.initialized ) {
258
287
nlohmann::json err = {
@@ -353,6 +382,59 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
353
382
json_send_classification_response (id, json_parsing_ms, stdin_ms,
354
383
res, &result, resp_buffer, resp_buffer_size);
355
384
}
385
+ else if (set_threshold.IsObject ()) {
386
+ if (!set_threshold.HasMember (" id" ) || !set_threshold[" id" ].IsInt ()) {
387
+ nlohmann::json err = {
388
+ {" id" , id},
389
+ {" success" , false },
390
+ {" error" , " set_threshold should have a numeric field 'id'" },
391
+ };
392
+ snprintf (resp_buffer, resp_buffer_size, " %s\n " , err.dump ().c_str ());
393
+ return ;
394
+ }
395
+
396
+ bool found_block = false ;
397
+ int block_id = set_threshold[" id" ].GetInt ();
398
+ for (size_t ix = 0 ; ix < ei_learning_blocks_size; ix++) {
399
+ const ei_learning_block_t learn_block = ei_learning_blocks[ix];
400
+ if (learn_block.blockId != block_id) continue ;
401
+
402
+ found_block = true ;
403
+
404
+ if (learn_block.infer_fn == run_gmm_anomaly) {
405
+ ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t *)learn_block.config ;
406
+
407
+ if (set_threshold.HasMember (" min_anomaly_score" ) && set_threshold[" min_anomaly_score" ].IsNumber ()) {
408
+ config->anomaly_threshold = set_threshold[" min_anomaly_score" ].GetFloat ();
409
+ }
410
+ }
411
+ else if (learn_block.infer_fn == run_nn_inference) {
412
+ ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t *)learn_block.config ;
413
+ if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
414
+ if (set_threshold.HasMember (" min_score" ) && set_threshold[" min_score" ].IsNumber ()) {
415
+ config->threshold = set_threshold[" min_score" ].GetFloat ();
416
+ }
417
+ }
418
+ }
419
+ }
420
+
421
+ if (!found_block) {
422
+ nlohmann::json err = {
423
+ {" id" , id},
424
+ {" success" , false },
425
+ {" error" , " set_threshold: cannot find learn block with this id" },
426
+ };
427
+ snprintf (resp_buffer, resp_buffer_size, " %s\n " , err.dump ().c_str ());
428
+ return ;
429
+ }
430
+
431
+ nlohmann::json resp = {
432
+ {" id" , id},
433
+ {" success" , true },
434
+ };
435
+ snprintf (resp_buffer, resp_buffer_size, " %s\n " , resp.dump ().c_str ());
436
+ return ;
437
+ }
356
438
else {
357
439
nlohmann::json err = {
358
440
{" id" , id},
@@ -364,6 +446,32 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
364
446
}
365
447
}
366
448
449
+ int print_metadata_main () {
450
+ char output_buffer[100 * 1024 ] = { 0 };
451
+
452
+ // Construct a hello msg into output_buffer
453
+ {
454
+ rapidjson::Document msg;
455
+ msg.SetObject ();
456
+ rapidjson::Document::AllocatorType& allocator = msg.GetAllocator ();
457
+ msg.AddMember (" id" , 1 , allocator);
458
+ msg.AddMember (" hello" , 1 , allocator);
459
+ json_message_handler (msg, output_buffer, 100 * 1024 , 0 , 0 );
460
+ }
461
+
462
+ // pretty print (by first parsing, then re-printing)
463
+ {
464
+ rapidjson::Document document;
465
+ document.Parse (output_buffer);
466
+ rapidjson::StringBuffer buffer;
467
+ rapidjson::PrettyWriter<rapidjson::StringBuffer> writer (buffer);
468
+ document.Accept (writer);
469
+ printf (" %s\n " , buffer.GetString ());
470
+ }
471
+
472
+ return 0 ;
473
+ }
474
+
367
475
int stdin_main () {
368
476
static char *stdin_buffer = (char *)malloc (STDIN_BUFFER_SIZE);
369
477
static char *response_buffer = (char *)calloc (STDIN_BUFFER_SIZE, 1 );
@@ -559,12 +667,16 @@ int main(int argc, char **argv) {
559
667
setvbuf (stdout, NULL , _IONBF, 0 );
560
668
561
669
if (argc < 2 ) {
562
- printf (" Requires one parameter (either: 'stdin' or the name of a socket)\n " );
670
+ printf (" Requires one parameter (either: '--print-info', ' stdin' or the name of a socket)\n " );
563
671
return 1 ;
564
672
}
565
673
566
674
state.initialized = false ;
567
675
676
+ if (strcmp (argv[1 ], " --print-info" ) == 0 ) {
677
+ printf (" Edge Impulse Linux impulse runner - printing model metadata\n " );
678
+ return print_metadata_main ();
679
+ }
568
680
if (strcmp (argv[1 ], " stdin" ) == 0 ) {
569
681
printf (" Edge Impulse Linux impulse runner - listening for JSON messages on stdin\n " );
570
682
return stdin_main ();
0 commit comments