yolo_seg_mqtt_async.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
#include "yolo_seg_mqtt_async.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>
  2. void YoloSegmentMqttAsyncClient::init(const std::string &file) {
  3. m_message_arrived_callback_ = CloudDataArrived;
  4. #ifdef ENABLE_TENSORRT_DETECT
  5. detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.engine", ETC_PATH PROJECT_NAME "/class.txt");
  6. #else
  7. detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.onnx", ETC_PATH PROJECT_NAME "/class.txt");
  8. #endif
  9. MqttAsyncClient::init_from_proto(file);
  10. }
  11. bool isElliptical(float x, float y, float a2, float b2) {
  12. return (x * x / a2 + y * y / b2) < 1;
  13. }
  14. uint32_t getPixelIndex(int row, int col) {
  15. // LOG(INFO) << "col " << col << ", row " << row;
  16. col = MAX(0, MIN(1280, col)) % 640;
  17. row = MAX(0, MIN(960, row)) % 480;
  18. // LOG(INFO) << "col " << col << ", row " << row;
  19. return row * 640 + col;
  20. }
  21. // 识别结果数据包含识别结果、时间序号、具体识别信息
  22. int YoloSegmentMqttAsyncClient::CloudDataArrived(void *client, char *topicName, int topicLen,
  23. MQTTAsync_message *message) {
  24. auto *tof_client = (YoloSegmentMqttAsyncClient *) client;
  25. // TODO:测试
  26. std::string topic = topicName;
  27. if (topic == "tof/ir") {
  28. // 数据反序列化
  29. int lable;
  30. cv::Mat merge_mat;
  31. JetStream::LabelImage testimage2;
  32. JetStream::LabelYolo seg_results;
  33. testimage2.ParseFromArray(message->payload, message->payloadlen);
  34. NetMessageTrans::Proto2lableMat(testimage2, lable, merge_mat);
  35. seg_results.set_label(lable);
  36. seg_results.add_boxes();
  37. seg_results.add_boxes();
  38. seg_results.add_boxes();
  39. seg_results.add_boxes();
  40. // 模型识别
  41. std::vector<Object> objs;
  42. cv::cvtColor(merge_mat, merge_mat, cv::COLOR_GRAY2RGB);
  43. tof_client->detector->detect(merge_mat, objs, merge_mat);
  44. // cv::imshow("merge_mat", merge_mat);
  45. // cv::waitKey(1);
  46. // 分离识别结果
  47. auto t = std::chrono::steady_clock::now();
  48. for (auto iter = objs.begin(); iter != objs.end(); iter++) {
  49. auto seg_points = tof_client->detector->getPointsFromObj(*iter);
  50. int device_index = (int(iter->rect.x / 640) * 0x01) | (int(iter->rect.y / 480) << 1);
  51. // 校验识别矩形框是否越界
  52. int device_index_check = (int((iter->rect.x + iter->rect.width) / 640) * 0x01) |
  53. (int((iter->rect.y + iter->rect.height) / 480) << 1);
  54. if (device_index != device_index_check) {
  55. // TODO:存图
  56. objs.erase(iter);
  57. iter--;
  58. continue;
  59. }
  60. int x_alpha = (device_index & 0x1);
  61. int y_alpha = ((device_index & 0x2) >> 1);
  62. seg_results.mutable_boxes(device_index)->set_x(iter->rect.x - x_alpha * merge_mat.cols / 2);
  63. seg_results.mutable_boxes(device_index)->set_y(iter->rect.y - y_alpha * merge_mat.rows / 2);
  64. seg_results.mutable_boxes(device_index)->set_width(iter->rect.width);
  65. seg_results.mutable_boxes(device_index)->set_height(iter->rect.height);
  66. seg_results.mutable_boxes(device_index)->set_confidence(iter->prob);
  67. if (iter->prob > 0.6) {
  68. // 内外椭圆参数
  69. float a1 = iter->rect.width * 0.5;
  70. float b1 = iter->rect.height * 0.5;
  71. float a2 = iter->rect.width * 0.57;
  72. float b2 = iter->rect.height * 0.57;
  73. //
  74. int center_x = iter->rect.x + iter->rect.width * 0.5;
  75. int center_y = iter->rect.y + iter->rect.height * 0.5;
  76. int min_x = MAX(0, center_x - a2);
  77. int min_y = MAX(0, center_y - b2);
  78. float a12 = a1 * a1;
  79. float b12 = b1 * b1;
  80. float a22 = a2 * a2;
  81. float b22 = b2 * b2;
  82. for (int x_value = min_x; x_value < center_x; x_value++) {
  83. auto t_x_value = x_value - center_x;
  84. for (int y_value = min_y; y_value < center_y; y_value++) {
  85. auto t_y_value = y_value - center_y;
  86. if (!isElliptical(t_x_value, t_y_value, a22, b22)) {
  87. continue;
  88. }
  89. auto x_value_sym = 2 * center_x - x_value;
  90. auto y_value_sym = 2 * center_y - y_value;
  91. if (isElliptical(t_x_value, t_y_value, a12, b12)) {
  92. seg_results.mutable_boxes(device_index)->add_wheels()->set_point(getPixelIndex(y_value, x_value));
  93. seg_results.mutable_boxes(device_index)->add_wheels()->set_point(getPixelIndex(y_value, x_value_sym));
  94. seg_results.mutable_boxes(device_index)->add_wheels()->set_point(getPixelIndex(y_value_sym, x_value));
  95. seg_results.mutable_boxes(device_index)->add_wheels()->set_point(getPixelIndex(y_value_sym, x_value_sym));
  96. continue;
  97. }
  98. seg_results.mutable_boxes(device_index)->add_tires()->set_point(getPixelIndex(y_value, x_value));
  99. seg_results.mutable_boxes(device_index)->add_tires()->set_point(getPixelIndex(y_value, x_value_sym));
  100. seg_results.mutable_boxes(device_index)->add_tires()->set_point(getPixelIndex(y_value_sym, x_value));
  101. seg_results.mutable_boxes(device_index)->add_tires()->set_point(getPixelIndex(y_value_sym, x_value_sym));
  102. }
  103. }
  104. }
  105. }
  106. // cv::imshow("merge_mat", merge_mat);
  107. // cv::waitKey(1);
  108. char data[seg_results.ByteSizeLong()];
  109. seg_results.SerializeToArray((void *) data, seg_results.ByteSizeLong());
  110. tof_client->SendMessage("tof/seg", data, seg_results.ByteSizeLong());
  111. }
  112. return 1;
  113. }