cnn3d_segmentation.h 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #ifndef __3DCNN__LOCATER__HH_
  2. #define __3DCNN__LOCATER__HH_
  3. #include "opencv/highgui.h"
  4. #include "opencv2/opencv.hpp"
  5. #include <pcl/point_types.h>
  6. #include <pcl/common/common.h>
  7. #include <string>
  8. #include "../error_code/error_code.h"
  9. /*
  10. * 3dcnn网络识别车两中轮胎点
  11. */
  12. class Cnn3d_segmentation
  13. {
  14. public:
  15. Cnn3d_segmentation(int l,int w,int h,int freq,int nClass);
  16. //设置3dcnn网络参数
  17. Error_manager set_parameter(int l, int w, int h,int freq,int classes);
  18. virtual ~Cnn3d_segmentation();
  19. //初始化网络参数
  20. //weights:参数文件
  21. virtual Error_manager init(std::string weights);
  22. //预测
  23. //cloud:输入点云
  24. //rect:输出旋转矩形框
  25. //save_dir:中间文件保存路径
  26. virtual Error_manager predict(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud,
  27. float& center_x,float& center_y,
  28. float& wheel_base,float& width,float& angle, std::string save_dir);
  29. protected:
  30. //根据设置的参数将点云转换成网络输入数格式
  31. float* generate_tensor(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud, float min_x, float max_x,
  32. float min_y, float max_y, float min_z, float max_z);
  33. //将识别结果转换成点云
  34. std::vector<pcl::PointCloud<pcl::PointXYZRGB>::Ptr> decodeCloud(pcl::PointCloud<pcl::PointXYZ>& cloud,
  35. float* data, float min_x, float max_x, float min_y, float max_y, float min_z, float max_z);
  36. protected:
  37. //判断矩形框是否找到
  38. //points:输入轮子中心点, 只能是3或者4个
  39. Error_manager isRect(std::vector<cv::Point2f>& points);
  40. //kmeans算法聚类,拆分四个轮子,返回四轮中心
  41. //cloud:输入点云
  42. std::vector<pcl::PointCloud<pcl::PointXYZ>::Ptr> kmeans(
  43. pcl::PointCloud<pcl::PointXYZRGB>::Ptr cloud, std::string cluster_file_path);
  44. bool check_box(cv::RotatedRect& box, pcl::PointCloud<pcl::PointXYZ>::Ptr cloud);
  45. ///IIU判断
  46. bool check_IOU(cv::RotatedRect& box, pcl::PointCloud<pcl::PointXYZ>::Ptr cloud);
  47. /*
  48. * 对点云作pca分析,然后在最小轴方向过滤, 防止轮胎左右两边出现噪声
  49. * cloud_in:输入点云
  50. * cloud_out:输出点云
  51. */
  52. Error_manager pca_minist_axis_filter(pcl::PointCloud<pcl::PointXYZ>::Ptr cloud_in,
  53. pcl::PointCloud<pcl::PointXYZ>::Ptr cloud_out);
  54. /*
  55. *计算两点距离
  56. */
  57. static double distance(cv::Point2f p1, cv::Point2f p2);
  58. protected:
  59. int m_lenth;
  60. int m_width;
  61. int m_height;
  62. int m_freq;
  63. int m_nClass;
  64. };
  65. #endif