12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- #ifndef __3DCNN__LOCATER__HH_
- #define __3DCNN__LOCATER__HH_
- #include "opencv/highgui.h"
- #include "opencv2/opencv.hpp"
- #include <pcl/point_types.h>
- #include <pcl/common/common.h>
- #include <string>
- #include "../error_code/error_code.h"
- /*
- * 3dcnn网络识别车两中轮胎点
- */
- class Cnn3d_segmentation
- {
- public:
- Cnn3d_segmentation(int l,int w,int h,int freq,int nClass);
- //设置3dcnn网络参数
- Error_manager set_parameter(int l, int w, int h,int freq,int classes);
- virtual ~Cnn3d_segmentation();
- //初始化网络参数
- //weights:参数文件
- virtual Error_manager init(std::string weights);
- //预测
- //cloud:输入点云
- //rect:输出旋转矩形框
- //save_dir:中间文件保存路径
- virtual Error_manager predict(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud,
- float& center_x,float& center_y,
- float& wheel_base,float& width,float& angle, std::string save_dir);
- protected:
- //根据设置的参数将点云转换成网络输入数格式
- float* generate_tensor(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud, float min_x, float max_x,
- float min_y, float max_y, float min_z, float max_z);
- //将识别结果转换成点云
- std::vector<pcl::PointCloud<pcl::PointXYZRGB>::Ptr> decodeCloud(pcl::PointCloud<pcl::PointXYZ>& cloud,
- float* data, float min_x, float max_x, float min_y, float max_y, float min_z, float max_z);
- protected:
- //判断矩形框是否找到
- //points:输入轮子中心点, 只能是3或者4个
- Error_manager isRect(std::vector<cv::Point2f>& points);
- //kmeans算法聚类,拆分四个轮子,返回四轮中心
- //cloud:输入点云
- std::vector<pcl::PointCloud<pcl::PointXYZ>::Ptr> kmeans(
- pcl::PointCloud<pcl::PointXYZRGB>::Ptr cloud, std::string cluster_file_path);
- bool check_box(cv::RotatedRect& box, pcl::PointCloud<pcl::PointXYZ>::Ptr cloud);
- ///IIU判断
- bool check_IOU(cv::RotatedRect& box, pcl::PointCloud<pcl::PointXYZ>::Ptr cloud);
- /*
- * 对点云作pca分析,然后在最小轴方向过滤, 防止轮胎左右两边出现噪声
- * cloud_in:输入点云
- * cloud_out:输出点云
- */
- Error_manager pca_minist_axis_filter(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud_in,
- pcl::PointCloud<pcl::PointXYZ>::Ptr cloud_out);
- /*
- *计算两点距离
- */
- static double distance(cv::Point2f p1, cv::Point2f p2);
- protected:
- int m_lenth;
- int m_width;
- int m_height;
- int m_freq;
- int m_nClass;
- };
- #endif
|