digits_lenet.cpp

// This example demonstrates digit recognition based on LeNet-5 and connected component analysis.
// It makes it possible for OpenCV beginners to run DNN models in real time using only the CPU.
// It reads frames from a camera in real time, predicts the digits in each frame, and displays the
// recognized digits as overlays on top of the original ones.
//
// For the best results, write the digits on white paper and make the paper fill the camera view.
//
// You can follow the guide below to train LeNet-5 yourself on the MNIST dataset:
// https://github.com/intel/caffe/blob/a3d5b022fe026e9092fc7abc7654b1162ab9940d/examples/mnist/readme.md
//
// You can also download an already trained model directly:
// https://github.com/zihaomu/opencv_digit_text_recognition_demo/tree/master/src
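//
// Example invocation (the executable and model file names below are only placeholders;
// substitute the files you trained or downloaded from the links above):
//   ./digits_lenet --modelTxt=lenet.prototxt --modelBin=lenet_iter_10000.caffemodel --thr=0.7
//   ./digits_lenet --modelTxt=lenet.prototxt --modelBin=lenet_iter_10000.caffemodel --input=digits.mp4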
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>

#include <iostream>
#include <vector>

using namespace cv;
using namespace cv::dnn;
const char *keys =
    "{ help h   |     | Print help message. }"
    "{ input i  |     | Path to an input image or video file. Skip this argument to capture frames from a camera. }"
    "{ device   | 0   | Camera device number. }"
    "{ modelBin |     | Path to the binary .caffemodel file that contains the trained network. }"
    "{ modelTxt |     | Path to the .prototxt file that contains the model definition of the trained network. }"
    "{ width    | 640 | Set the width of the camera. }"
    "{ height   | 480 | Set the height of the camera. }"
    "{ thr      | 0.7 | Confidence threshold. }";
// Find best class for the blob (i.e. class with maximal probability)
static void getMaxClass(const Mat &probBlob, int &classId, double &classProb);
void predictor(Net net, const Mat &roi, int &classId, double &probability);
int main(int argc, char **argv)
{
    // Parse command line arguments.
    CommandLineParser parser(argc, argv, keys);
    if (argc == 1 || parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }
    int vWidth = parser.get<int>("width");
    int vHeight = parser.get<int>("height");
    float confThreshold = parser.get<float>("thr");
    std::string modelTxt = parser.get<String>("modelTxt");
    std::string modelBin = parser.get<String>("modelBin");
    Net net;
    try
    {
        net = readNet(modelTxt, modelBin);
    }
    catch (cv::Exception &ee)
    {
        std::cerr << "Exception: " << ee.what() << std::endl;
        std::cout << "Can't load the network using the following files:" << std::endl;
        std::cout << "modelTxt: " << modelTxt << std::endl;
        std::cout << "modelBin: " << modelBin << std::endl;
        return 1;
    }
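    // The header promises real-time inference using only the CPU. The default OpenCV DNN
    // backend and target already run on the CPU, so the two calls below do not change
    // behaviour; they only make that intent explicit. This is an optional addition,
    // not part of the original sample.
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);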
    const std::string resultWinName = "Please write the number on white paper so that it fills the camera view.";
    const std::string preWinName = "Preprocessing";
    namedWindow(preWinName, WINDOW_AUTOSIZE);
    namedWindow(resultWinName, WINDOW_AUTOSIZE);

    Mat labels, stats, centroids;
    Point position;
    Rect getRectangle;
    bool ifDrawingBox = false;
    int classId = 0;
    double probability = 0;
    Rect basicRect = Rect(0, 0, vWidth, vHeight);
    Mat rawImage;
    double fps = 0;
    // Open a video file or an image file or a camera stream.
    VideoCapture cap;
    if (parser.has("input"))
        cap.open(parser.get<String>("input"));
    else
        cap.open(parser.get<int>("device"));
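    // The "width"/"height" flags are documented as setting the camera resolution, but the
    // original sample only uses them to build basicRect. Requesting that resolution from the
    // capture device, as sketched below, keeps basicRect consistent with the actual frame
    // size; note that not every camera or backend honours these properties.
    if (!parser.has("input"))
    {
        cap.set(CAP_PROP_FRAME_WIDTH, vWidth);
        cap.set(CAP_PROP_FRAME_HEIGHT, vHeight);
    }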
    TickMeter tm;

    while (waitKey(1) < 0)
    {
        cap >> rawImage;
        if (rawImage.empty())
        {
            waitKey();
            break;
        }
        tm.reset();
        tm.start();

        Mat image = rawImage.clone();
        // Image preprocessing
        cvtColor(image, image, COLOR_BGR2GRAY);
        GaussianBlur(image, image, Size(3, 3), 2, 2);
        adaptiveThreshold(image, image, 255, ADAPTIVE_THRESH_MEAN_C, THRESH_BINARY, 25, 10);
        bitwise_not(image, image);
        Mat element = getStructuringElement(MORPH_RECT, Size(3, 3), Point(-1, -1));
        dilate(image, image, element, Point(-1, -1), 1);
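        // After adaptive thresholding and inversion, the dark pen strokes have become white
        // blobs on a black background, and the dilation thickens them slightly. This is the
        // binary mask that connectedComponentsWithStats expects: each digit should now form
        // one connected white region.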
        // Find connected components; label 0 is the background, so start from 1.
        int nccomps = cv::connectedComponentsWithStats(image, labels, stats, centroids);
        for (int i = 1; i < nccomps; i++)
        {
            ifDrawingBox = false;
            // Extend the bounding box of the connected component for easier recognition
            if (stats.at<int>(i, CC_STAT_AREA) > 80 && stats.at<int>(i, CC_STAT_AREA) < 3000)
            {
                ifDrawingBox = true;
                int left = stats.at<int>(i, CC_STAT_HEIGHT) / 4;
                getRectangle = Rect(stats.at<int>(i, CC_STAT_LEFT) - left,
                                    stats.at<int>(i, CC_STAT_TOP) - left,
                                    stats.at<int>(i, CC_STAT_WIDTH) + 2 * left,
                                    stats.at<int>(i, CC_STAT_HEIGHT) + 2 * left);
                getRectangle &= basicRect;
            }
            if (ifDrawingBox && !getRectangle.empty())
            {
                Mat roi = image(getRectangle);
                predictor(net, roi, classId, probability);
                if (probability < confThreshold)
                    continue;
                rectangle(rawImage, getRectangle, Scalar(128, 255, 128), 2);
                position = Point(getRectangle.br().x - 7, getRectangle.br().y + 25);
                putText(rawImage, std::to_string(classId), position, FONT_HERSHEY_COMPLEX, 1.0, Scalar(128, 128, 255), 2);
            }
        }
        tm.stop();
        fps = 1 / tm.getTimeSec();
        std::string fpsString = format("Inference FPS: %.2f.", fps);
        putText(rawImage, fpsString, Point(5, 20), FONT_HERSHEY_SIMPLEX, 0.6, Scalar(128, 255, 128));

        imshow(resultWinName, rawImage);
        imshow(preWinName, image);
    }
    return 0;
}

static void getMaxClass(const Mat &probBlob, int &classId, double &classProb)
{
    // Flatten the probability blob to a single row (1 x number_of_classes).
    Mat probMat = probBlob.reshape(1, 1);
    Point classNumber;
    // The location of the maximum gives the predicted class, its value the confidence.
    minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);
    classId = classNumber.x;
}
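
// Note: classProb is a true probability only if the network's final layer applies softmax.
// The LeNet-5 deploy definition from the Caffe MNIST tutorial linked above ends with a
// Softmax layer, so this should hold for the suggested models; for a model that outputs
// raw logits, the value would only be a relative score.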

void predictor(Net net, const Mat &roi, int &classId, double &probability)
{
    Mat pred;
    // Convert the Mat to a batch of images (a 4D blob), resized to the 28 x 28 input of LeNet-5.
    Mat inputBlob = dnn::blobFromImage(roi, 1.0, Size(28, 28));
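    // Depending on how the model was trained, blobFromImage may also need a scale factor and
    // mean subtraction (e.g. the Caffe MNIST tutorial scales inputs by 1/256). The call above
    // follows the original sample and assumes the suggested pretrained model works with raw
    // 0-255 pixel values; adjust the parameters here if your own model expects normalised input.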
    // Set the network input
    net.setInput(inputBlob);
    // Compute output
    pred = net.forward();
    getMaxClass(pred, classId, probability);
}