Skip to content

Settable confidence thresholds #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ opencv/
*.onnx
convert-savedmodel/saved_model/
debug.bmp
.DS_Store
6 changes: 5 additions & 1 deletion source/custom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ int main(int argc, char **argv) {
numpy::signal_from_buffer(&raw_features[0], raw_features.size(), &signal);

EI_IMPULSE_ERROR res = run_classifier(&signal, &result, false);
if (res != EI_IMPULSE_OK) {
printf("run_classifier failed (%d)\n", (int)res);
return 1;
}
// print the predictions
printf("Predictions (DSP: %d ms., Classification: %d ms., Anomaly: %d ms.): \n",
result.timing.dsp, result.timing.classification, result.timing.anomaly);
Expand Down Expand Up @@ -119,7 +123,7 @@ int main(int argc, char **argv) {
printf("\n");
}
#endif
#if EI_CLASSIFIER_HAS_ANOMALY == 3 // visual AD
#if EI_CLASSIFIER_HAS_ANOMALY == EI_ANOMALY_TYPE_VISUAL_GMM // visual AD
printf("#Visual anomaly grid results:\n");
for (uint32_t i = 0; i < result.visual_ad_count; i++) {
ei_impulse_result_bounding_box_t bb = result.visual_ad_grid_cells[i];
Expand Down
124 changes: 118 additions & 6 deletions source/eim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
#include <cstring>
#include <iostream>
#include <sstream>
// #include <thread>
#include <chrono>
#include <vector>
#include "edge-impulse-sdk/classifier/ei_run_classifier.h"
#include "json/json.hpp"
#include "rapidjson/document.h"
#include "rapidjson/prettywriter.h"
#include "rapidjson/stringbuffer.h"
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
Expand Down Expand Up @@ -165,11 +168,12 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re

auto id = id_v.GetInt();

rapidjson::Value& version = msg["hello"];
rapidjson::Value& hello = msg["hello"];
rapidjson::Value& classify_data = msg["classify"];
rapidjson::Value& classify_data_continuous = msg["classify_continuous"];
rapidjson::Value& set_threshold = msg["set_threshold"];

if (version.IsInt()) {
if (hello.IsInt()) {
if (state.initialized) {
nlohmann::json err = {
{"id", id},
Expand All @@ -180,11 +184,11 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
return;
}

if (version.GetInt() != 1) {
if (hello.GetInt() != 1) {
nlohmann::json err = {
{"id", id},
{"success", false},
{"error", "Invalid 'version', only 1 supported"},
{"error", "Invalid value for 'hello', only 1 supported"},
};
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
return;
Expand Down Expand Up @@ -219,6 +223,30 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
const char *model_type = "classification";
#endif // EI_CLASSIFIER_OBJECT_DETECTION

// keep track of configurable thresholds
nlohmann::json thresholds = nlohmann::json::array();
for (size_t ix = 0; ix < ei_learning_blocks_size; ix++) {
const ei_learning_block_t learn_block = ei_learning_blocks[ix];
if (learn_block.infer_fn == run_gmm_anomaly) {
ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t*)learn_block.config;
thresholds.push_back({
{ "id", learn_block.blockId },
{ "type", "anomaly_gmm" },
{ "min_anomaly_score", config->anomaly_threshold }
});
}
else if (learn_block.infer_fn == run_nn_inference) {
ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t*)learn_block.config;
if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
thresholds.push_back({
{ "id", learn_block.blockId },
{ "type", "object_detection" },
{ "min_score", config->threshold }
});
}
}
}

nlohmann::json resp = {
{"id", id},
{"success", true},
Expand Down Expand Up @@ -246,13 +274,14 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
{"slice_size", EI_CLASSIFIER_SLICE_SIZE},
{"use_continuous_mode", EI_CLASSIFIER_SENSOR == EI_CLASSIFIER_SENSOR_MICROPHONE},
{"inferencing_engine", EI_CLASSIFIER_INFERENCING_ENGINE},
{"thresholds", thresholds},
}},
};

snprintf(resp_buffer, resp_buffer_size, "%s\n", resp.dump().c_str());

state.initialized = true;
state.version = version.GetInt();
state.version = hello.GetInt();
}
else if (!state.initialized) {
nlohmann::json err = {
Expand Down Expand Up @@ -353,6 +382,59 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
json_send_classification_response(id, json_parsing_ms, stdin_ms,
res, &result, resp_buffer, resp_buffer_size);
}
else if (set_threshold.IsObject()) {
if (!set_threshold.HasMember("id") || !set_threshold["id"].IsInt()) {
nlohmann::json err = {
{"id", id},
{"success", false},
{"error", "set_threshold should have a numeric field 'id'"},
};
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
return;
}

bool found_block = false;
int block_id = set_threshold["id"].GetInt();
for (size_t ix = 0; ix < ei_learning_blocks_size; ix++) {
const ei_learning_block_t learn_block = ei_learning_blocks[ix];
if (learn_block.blockId != block_id) continue;

found_block = true;

if (learn_block.infer_fn == run_gmm_anomaly) {
ei_learning_block_config_anomaly_gmm_t *config = (ei_learning_block_config_anomaly_gmm_t*)learn_block.config;

if (set_threshold.HasMember("min_anomaly_score") && set_threshold["min_anomaly_score"].IsNumber()) {
config->anomaly_threshold = set_threshold["min_anomaly_score"].GetFloat();
}
}
else if (learn_block.infer_fn == run_nn_inference) {
ei_learning_block_config_tflite_graph_t *config = (ei_learning_block_config_tflite_graph_t*)learn_block.config;
if (config->classification_mode == EI_CLASSIFIER_CLASSIFICATION_MODE_OBJECT_DETECTION) {
if (set_threshold.HasMember("min_score") && set_threshold["min_score"].IsNumber()) {
config->threshold = set_threshold["min_score"].GetFloat();
}
}
}
}

if (!found_block) {
nlohmann::json err = {
{"id", id},
{"success", false},
{"error", "set_threshold: cannot find learn block with this id"},
};
snprintf(resp_buffer, resp_buffer_size, "%s\n", err.dump().c_str());
return;
}

nlohmann::json resp = {
{"id", id},
{"success", true},
};
snprintf(resp_buffer, resp_buffer_size, "%s\n", resp.dump().c_str());
return;
}
else {
nlohmann::json err = {
{"id", id},
Expand All @@ -364,6 +446,32 @@ void json_message_handler(rapidjson::Document &msg, char *resp_buffer, size_t re
}
}

int print_metadata_main() {
char output_buffer[100 * 1024] = { 0 };

// Construct a hello msg into output_buffer
{
rapidjson::Document msg;
msg.SetObject();
rapidjson::Document::AllocatorType& allocator = msg.GetAllocator();
msg.AddMember("id", 1, allocator);
msg.AddMember("hello", 1, allocator);
json_message_handler(msg, output_buffer, 100 * 1024, 0, 0);
}

// pretty print (by first parsing, then re-printing)
{
rapidjson::Document document;
document.Parse(output_buffer);
rapidjson::StringBuffer buffer;
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
document.Accept(writer);
printf("%s\n", buffer.GetString());
}

return 0;
}

int stdin_main() {
static char *stdin_buffer = (char *)malloc(STDIN_BUFFER_SIZE);
static char *response_buffer = (char *)calloc(STDIN_BUFFER_SIZE, 1);
Expand Down Expand Up @@ -559,12 +667,16 @@ int main(int argc, char **argv) {
setvbuf(stdout, NULL, _IONBF, 0);

if (argc < 2) {
printf("Requires one parameter (either: 'stdin' or the name of a socket)\n");
printf("Requires one parameter (either: '--print-info', 'stdin' or the name of a socket)\n");
return 1;
}

state.initialized = false;

if (strcmp(argv[1], "--print-info") == 0) {
printf("Edge Impulse Linux impulse runner - printing model metadata\n");
return print_metadata_main();
}
if (strcmp(argv[1], "stdin") == 0) {
printf("Edge Impulse Linux impulse runner - listening for JSON messages on stdin\n");
return stdin_main();
Expand Down