#pragma once #include #include #ifndef LIB_SIFT_API #ifdef LIB_EXPORTS #if defined(_MSC_VER) #define LIB_SIFT_API __declspec(dllexport) #else #define LIB_SIFT_API __attribute__((visibility("default"))) #endif #else #if defined(_MSC_VER) #define LIB_SIFT_API #else #define LIB_SIFT_API #endif #endif #endif /* * PointSift 点云分割类 * 加载PointSift网络,分割点云 */ class LIB_SIFT_API PointSifter { public: //初始化PointSift //point_num:pointsift输入点数 //cls_num:PointSift输出类别数 PointSifter(int point_num, int cls_num); ~PointSifter(); //加载网络参数 //meta:tensorflow网络结构定义文件 //cpkt:tensorflow网络权重 bool Load(std::string meta, std::string cpkt); //预测 //data:输入数据,大小为 输入点数*3 //output:输出数据,大小为 输入点数*类别数 bool Predict(float* data, float* output); //错误原因 std::string LastError(); private: PointSifter(); protected: std::mutex m_mutex; std::string m_error; bool m_bInit; int m_point_num; int m_cls_num; void* m_sess; };