// wheel-detector.cpp
#include "wheel-detector.h"

#include <cmath>
  2. TensorrtWheelDetector::TensorrtWheelDetector(const std::string &model_file, const std::string &class_file){
  3. yolov8_ = new Inference(model_file, cv::Size{640, 640}, class_file, false);
  4. }
  5. TensorrtWheelDetector::~TensorrtWheelDetector(){
  6. }
  7. bool TensorrtWheelDetector::detect(cv::Mat& img,std::vector<Object>& objs){
  8. std::vector<Detection> rets = yolov8_->runInference(img);
  9. for (auto &ret: rets) {
  10. Object obj;
  11. obj.rect = ret.box;
  12. obj.label = ret.class_id;
  13. obj.prob = ret.confidence;
  14. objs.push_back(obj);
  15. }
  16. return true;
  17. }
  18. bool TensorrtWheelDetector::detect(cv::Mat& img,std::vector<Object>& objs,cv::Mat& res){
  19. std::vector<Detection> rets = yolov8_->runInference(img);
  20. for (auto &ret: rets) {
  21. Object obj;
  22. obj.rect = ret.box;
  23. obj.label = ret.class_id;
  24. obj.prob = ret.confidence;
  25. obj.boxMask = img;
  26. objs.push_back(obj);
  27. }
  28. return true;
  29. }
  30. std::vector<cv::Point> TensorrtWheelDetector::getPointsFromObj(const Object &obj){
  31. std::vector<cv::Point> ret;
  32. int x=int(obj.rect.x+0.5);
  33. int y=int(obj.rect.y+0.5);
  34. int width=int(obj.rect.width);
  35. int height=int(obj.rect.height);
  36. for(int i=0;i<height;++i){
  37. for(int j=0;j<width;++j){
  38. ret.emplace_back(x+j,y+i);
  39. }
  40. }
  41. return ret;
  42. }