Browse Source

add mqtt async

LiuZe 1 year ago
commit
0eb002a57c
4 changed files with 162 additions and 0 deletions
  1. 5 0
      net_message_trans.cpp
  2. 32 0
      net_message_trans.h
  3. 82 0
      yolo_seg_mqtt_async.cpp
  4. 43 0
      yolo_seg_mqtt_async.h

+ 5 - 0
net_message_trans.cpp

@@ -0,0 +1,5 @@
+//
+// Created by zx on 2023/12/12.
+//
+
+#include "net_message_trans.h"

+ 32 - 0
net_message_trans.h

@@ -0,0 +1,32 @@
+#pragma once
+
+#include "proto/def.grpc.pb.h"
+#include <opencv2/opencv.hpp>
+#include "tool/log.hpp"
+
+class NetMessageTrans {
+public:
+    static void lableMat2Proto(int lable, cv::Mat &mat, JetStream::LabelImage &image) {
+        image.set_label(lable);
+        image.mutable_ir()->set_width(mat.cols);
+        image.mutable_ir()->set_height(mat.rows);
+        image.mutable_ir()->set_channel(mat.type());
+        image.mutable_ir()->set_data(mat.data, mat.dataend - mat.datastart);
+        // LOG(INFO) << "mat.dataend - mat.datastart = " << mat.dataend - mat.datastart;
+    }
+
+    static void Proto2lableMat(JetStream::LabelImage &image, int &lable, cv::Mat &mat) {
+        lable = image.label();
+        mat = cv::Mat(image.ir().height(), image.ir().width(), image.ir().channel(), (void*)image.ir().data().data());
+    }
+
+    static void getSegBoxData() {
+
+    }
+
+    static void data2SegBox() {
+
+    }
+};
+
+

+ 82 - 0
yolo_seg_mqtt_async.cpp

@@ -0,0 +1,82 @@
+#include "yolo_seg_mqtt_async.h"
+
+void YoloSegmentMqttAsyncClient::init(const std::string &file) {
+    m_message_arrived_callback_ = CloudDataArrived;
+
+    MqttAsyncClient::init_from_proto(file);
+
+#ifdef ENABLE_TENSORRT_DETECT
+    detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.engine", ETC_PATH PROJECT_NAME "/class.txt");
+#else
+    detector = new TensorrtWheelDetector(ETC_PATH PROJECT_NAME "/model.onnx", ETC_PATH PROJECT_NAME "/class.txt");
+#endif
+}
+
+// 识别结果数据包含识别结果、时间序号、具体识别信息
+int YoloSegmentMqttAsyncClient::CloudDataArrived(void *client, char *topicName, int topicLen,
+                                           MQTTAsync_message *message) {
+    auto *tof_client = (YoloSegmentMqttAsyncClient *)client;
+    // TODO:测试
+    std::string topic = topicName;
+    if (topic == "tof/ir") {
+        // 数据反序列化
+        int lable;
+        cv::Mat merge_mat;
+        JetStream::LabelImage testimage2;
+        testimage2.ParseFromArray(message->payload, message->payloadlen);
+        NetMessageTrans::Proto2lableMat(testimage2, lable, merge_mat);
+
+        // 模型识别
+        std::vector<Object> objs;
+        cv::cvtColor(merge_mat, merge_mat, cv::COLOR_GRAY2RGB);
+        tof_client->detector->detect(merge_mat, objs, merge_mat);
+
+        // 分离识别结果
+        auto t = std::chrono::steady_clock::now();
+        JetStream::LabelYolo seg_results;
+        seg_results.set_label(lable);
+        seg_results.add_boxes();
+        seg_results.add_boxes();
+        seg_results.add_boxes();
+        seg_results.add_boxes();
+        for (auto iter = objs.begin(); iter != objs.end(); iter++) {
+            auto seg_points = tof_client->detector->getPointsFromObj(*iter);
+            int device_index = (int(iter->rect.x / 640) * 0x01) | (int(iter->rect.y / 480) << 1);
+            // 校验识别矩形框是否越界
+            int device_index_check = (int((iter->rect.x + iter->rect.width) / 640) * 0x01) | (int((iter->rect.y + iter->rect.height) / 480) << 1);
+            if (device_index != device_index_check) {
+                // TODO:存图
+                objs.erase(iter);
+                iter--;
+                continue;
+            }
+
+            if (iter->prob > 0.9) {
+                int box_contour[480][2] = {0};
+                int x_alpha = (device_index & 0x1);
+                int y_alpha = ((device_index & 0x2) >> 1);
+                for (auto &pt: seg_points) {    // use 7.5~9ms
+                    pt.x -= x_alpha * merge_mat.cols / 2;
+                    pt.y -= y_alpha * merge_mat.rows / 2;
+                    if (box_contour[pt.y][0] == 0 && box_contour[pt.y][1] == 0) {
+                        box_contour[pt.y][0] = pt.x;
+                        box_contour[pt.y][1] = pt.x;
+                    } else {
+                        box_contour[pt.y][0] = MIN(pt.x, box_contour[pt.y][0]);
+                        box_contour[pt.y][1] = MAX(pt.x, box_contour[pt.y][1]);
+                    }
+                }
+                for (auto & ends : box_contour) {
+                    auto line = seg_results.mutable_boxes(device_index)->add_lines();
+                    line->set_begin(ends[0]);
+                    line->set_end(ends[1]);
+                }
+            }
+        }
+
+        char data[seg_results.ByteSizeLong()];
+        seg_results.SerializeToArray((void *)data, seg_results.ByteSizeLong());
+        tof_client->SendMessage("tof/seg", data, seg_results.ByteSizeLong());
+    }
+    return 1;
+}

+ 43 - 0
yolo_seg_mqtt_async.h

@@ -0,0 +1,43 @@
+#pragma once
+
+#include <pcl/point_types.h>
+#include <pcl/point_cloud.h>
+
+#include "pahoc/mqtt_async.h"
+#include "tool/log.hpp"
+#include "net_message_trans.h"
+
+#ifdef ENABLE_TENSORRT_DETECT
+#include "detect/tensorrt_detect/wheel-detector.h"
+#else
+#include "detect/onnx_detect/wheel-detector.h"
+#endif
+
+#include "communication/data_buf_lable.hpp"
+
+class YoloSegmentMqttAsyncClient : public MqttAsyncClient {
+    friend class MqttAsyncClient;
+public:
+    static YoloSegmentMqttAsyncClient *iter() {
+        static YoloSegmentMqttAsyncClient *instance = nullptr;
+        if (instance == nullptr) {
+            instance = new YoloSegmentMqttAsyncClient();
+        }
+        return instance;
+    }
+
+    void init(const std::string &file);
+
+protected:
+
+private:
+    YoloSegmentMqttAsyncClient() = default;
+
+    static int CloudDataArrived(void *client, char *topicName, int topicLen, MQTTAsync_message *message);
+
+public:
+    TensorrtWheelDetector *detector = nullptr;
+
+private:
+
+};