// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. #include "perf_precomp.hpp" namespace opencv_test { struct LstmParams { // Batch size int nrSamples; // Size of the input vector int inputSize; // Size of the internal state vector int hiddenSize; // Number of timesteps for the LSTM int nrSteps; }; static inline void PrintTo(const LstmParams& params, ::std::ostream* os) { (*os) << "BATCH=" << params.nrSamples << ", IN=" << params.inputSize << ", HIDDEN=" << params.hiddenSize << ", TS=" << params.nrSteps; } static const LstmParams testLstmConfigs[] = { {1, 192, 192, 100}, {1, 1024, 192, 100}, {1, 64, 192, 100}, {1, 192, 512, 100}, {64, 192, 192, 2}, {64, 1024, 192, 2}, {64, 64, 192, 2}, {64, 192, 512, 2}, {128, 192, 192, 2}, {128, 1024, 192, 2}, {128, 64, 192, 2}, {128, 192, 512, 2} }; class Layer_LSTM : public TestBaseWithParam {}; PERF_TEST_P_(Layer_LSTM, lstm) { const LstmParams& params = GetParam(); LayerParams lp; lp.type = "LSTM"; lp.name = "testLstm"; lp.set("produce_cell_output", false); lp.set("use_timestamp_dim", true); Mat weightH(params.hiddenSize * 4, params.hiddenSize, CV_32FC1, cv::Scalar(0)); Mat weightX(params.hiddenSize * 4, params.inputSize, CV_32FC1, cv::Scalar(0)); Mat bias(params.hiddenSize * 4, 1, CV_32FC1, cv::Scalar(0)); Mat hInternal(params.nrSteps, params.hiddenSize, CV_32FC1, cv::Scalar(0)); Mat cInternal(params.nrSteps, params.hiddenSize, CV_32FC1, cv::Scalar(0)); lp.blobs.push_back(weightH); lp.blobs.push_back(weightX); lp.blobs.push_back(bias); lp.blobs.push_back(hInternal); lp.blobs.push_back(cInternal); std::vector inputDims; inputDims.push_back(params.nrSamples); inputDims.push_back(params.nrSteps); inputDims.push_back(params.inputSize); Mat input(inputDims.size(), inputDims.data(), CV_32FC1); input = cv::Scalar(0); Net net; net.addLayerToPrev(lp.name, lp.type, lp); net.setInput(input); // Warm up std::vector outputs(2); net.forward(outputs, "testLstm"); TEST_CYCLE() { net.forward(outputs, "testLstm"); } SANITY_CHECK_NOTHING(); } INSTANTIATE_TEST_CASE_P(/**/, Layer_LSTM, testing::ValuesIn(testLstmConfigs)); } // namespace