yolo_seg_mqtt_async.cpp 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  #include "yolo_seg_mqtt_async.h"

  #include <algorithm>
  #include <array>
  #include <cstddef>
  #include <string>
  #include <vector>
  2. void YoloSegmentMqttAsyncClient::init(const std::string &file) {
  3. m_message_arrived_callback_ = CloudDataArrived;
  4. #ifdef ENABLE_TENSORRT_DETECT
  5. detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.engine", ETC_PATH PROJECT_NAME "/class.txt");
  6. #else
  7. detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.onnx", ETC_PATH PROJECT_NAME "/class.txt");
  8. #endif
  9. MqttAsyncClient::init_from_proto(file);
  10. }
  11. // 识别结果数据包含识别结果、时间序号、具体识别信息
  12. int YoloSegmentMqttAsyncClient::CloudDataArrived(void *client, char *topicName, int topicLen,
  13. MQTTAsync_message *message) {
  14. auto *tof_client = (YoloSegmentMqttAsyncClient *) client;
  15. // TODO:测试
  16. std::string topic = topicName;
  17. if (topic == "tof/ir") {
  18. // 数据反序列化
  19. int lable;
  20. cv::Mat merge_mat;
  21. JetStream::LabelImage testimage2;
  22. JetStream::LabelYolo seg_results;
  23. testimage2.ParseFromArray(message->payload, message->payloadlen);
  24. NetMessageTrans::Proto2lableMat(testimage2, lable, merge_mat);
  25. seg_results.set_label(lable);
  26. seg_results.add_boxes();
  27. seg_results.add_boxes();
  28. seg_results.add_boxes();
  29. seg_results.add_boxes();
  30. // 模型识别
  31. std::vector<Object> objs;
  32. cv::cvtColor(merge_mat, merge_mat, cv::COLOR_GRAY2RGB);
  33. tof_client->detector->detect(merge_mat, objs, merge_mat);
  34. // cv::imshow("merge_mat", merge_mat);
  35. // cv::waitKey(1);
  36. // 分离识别结果
  37. auto t = std::chrono::steady_clock::now();
  38. for (auto iter = objs.begin(); iter != objs.end(); iter++) {
  39. auto seg_points = tof_client->detector->getPointsFromObj(*iter);
  40. int device_index = (int(iter->rect.x / 640) * 0x01) | (int(iter->rect.y / 480) << 1);
  41. // 校验识别矩形框是否越界
  42. int device_index_check = (int((iter->rect.x + iter->rect.width) / 640) * 0x01) |
  43. (int((iter->rect.y + iter->rect.height) / 480) << 1);
  44. if (device_index != device_index_check) {
  45. // TODO:存图
  46. objs.erase(iter);
  47. iter--;
  48. continue;
  49. }
  50. int x_alpha = (device_index & 0x1);
  51. int y_alpha = ((device_index & 0x2) >> 1);
  52. seg_results.mutable_boxes(device_index)->set_x(iter->rect.x - x_alpha * merge_mat.cols / 2);
  53. seg_results.mutable_boxes(device_index)->set_y(iter->rect.y - y_alpha * merge_mat.rows / 2);
  54. seg_results.mutable_boxes(device_index)->set_width(iter->rect.width);
  55. seg_results.mutable_boxes(device_index)->set_height(iter->rect.height);
  56. seg_results.mutable_boxes(device_index)->set_confidence(iter->prob);
  57. if (iter->prob > 0.6) {
  58. // LOG(INFO) << device_index << ": " << iter->prob;
  59. int box_contour[480][2] = {0};
  60. for (auto &pt: seg_points) { // use 7.5~9ms
  61. pt.x -= x_alpha * merge_mat.cols / 2;
  62. pt.y -= y_alpha * merge_mat.rows / 2;
  63. if (box_contour[pt.y][0] == 0 && box_contour[pt.y][1] == 0) {
  64. box_contour[pt.y][0] = pt.x;
  65. box_contour[pt.y][1] = pt.x;
  66. } else {
  67. box_contour[pt.y][0] = MIN(pt.x, box_contour[pt.y][0]);
  68. box_contour[pt.y][1] = MAX(pt.x, box_contour[pt.y][1]);
  69. }
  70. }
  71. for (auto &ends: box_contour) {
  72. auto line = seg_results.mutable_boxes(device_index)->add_lines();
  73. line->set_begin(ends[0]);
  74. line->set_end(ends[1]);
  75. }
  76. }
  77. }
  78. char data[seg_results.ByteSizeLong()];
  79. seg_results.SerializeToArray((void *) data, seg_results.ByteSizeLong());
  80. tof_client->SendMessage("tof/seg", data, seg_results.ByteSizeLong());
  81. }
  82. return 1;
  83. }