// wheel-detector.cpp
  #include "wheel-detector.h"

  #include <cstdio>
  #include <memory>
  2. TensorrtWheelDetector::TensorrtWheelDetector(const std::string &model_file, const std::string &class_file){
  3. cudaSetDevice(0);
  4. yolov8_ = new YOLOv8_seg(model_file);
  5. yolov8_->make_pipe(false);
  6. imgsz_=cv::Size{640, 480};
  7. seg_h_ = 120;
  8. seg_w_ = 160;
  9. seg_channels_ = 32;
  10. }
  11. TensorrtWheelDetector::~TensorrtWheelDetector(){
  12. if(yolov8_!=nullptr){
  13. delete yolov8_;
  14. yolov8_=nullptr;
  15. }
  16. }
  17. bool TensorrtWheelDetector::detect(cv::Mat& img,std::vector<Object>& objs){
  18. if(yolov8_==nullptr){
  19. return false;
  20. }
  21. if(img.size()!=imgsz_){
  22. printf("imgsz required [%d,%d],but input is [%d,%d]\n",imgsz_.height,imgsz_.width,img.rows,img.cols);
  23. return false;
  24. }
  25. yolov8_->copy_from_Mat(img, imgsz_);
  26. yolov8_->infer();
  27. float score_thres=0.9;
  28. float iou_thres=0.65;
  29. int topk=10;
  30. yolov8_->postprocess(objs, score_thres, iou_thres, topk, seg_channels_, seg_h_, seg_w_);
  31. return true;
  32. }
  33. bool TensorrtWheelDetector::detect(cv::Mat& img,std::vector<Object>& objs,cv::Mat& res){
  34. if(detect(img,objs))
  35. {
  36. const std::vector<std::string> classes={"none","wheel"};
  37. const std::vector<std::vector<unsigned int>> colors = {{0, 114, 189}, {0, 255, 0}};
  38. const std::vector<std::vector<unsigned int>> mask_colors = {{255, 56, 56}, {255, 0, 0}};
  39. yolov8_->draw_objects(img, res, objs, classes, colors, mask_colors);
  40. return true;
  41. }else{
  42. return false;
  43. }
  44. }
  45. std::vector<cv::Point> TensorrtWheelDetector::getPointsFromObj(Object obj){
  46. std::vector<cv::Point> ret;
  47. int x=int(obj.rect.x+0.5);
  48. int y=int(obj.rect.y+0.5);
  49. int width=int(obj.rect.width);
  50. int height=int(obj.rect.height);
  51. for(int i=0;i<height;++i){
  52. for(int j=0;j<width;++j){
  53. if(obj.boxMask.at<uchar>(i,j)!=0){
  54. ret.push_back(cv::Point((x+j),y+i));
  55. }
  56. }
  57. }
  58. return ret;
  59. }