MNNTestSuite.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
//
// MNNTestSuite.cpp
// MNN
//
// Created by MNN on 2019/01/10.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <MNN/AutoTime.hpp>
#include "MNNTestSuite.h"
  13. MNNTestSuite* MNNTestSuite::gInstance = NULL;
  14. MNNTestSuite* MNNTestSuite::get() {
  15. if (gInstance == NULL)
  16. gInstance = new MNNTestSuite;
  17. return gInstance;
  18. }
  19. MNNTestSuite::~MNNTestSuite() {
  20. for (int i = 0; i < mTests.size(); ++i) {
  21. delete mTests[i];
  22. }
  23. mTests.clear();
  24. }
  25. void MNNTestSuite::add(MNNTestCase* test, const char* name) {
  26. test->name = name;
  27. mTests.push_back(test);
  28. }
  29. static void printTestResult(int wrong, int right, const char* flag) {
  30. MNN_PRINT("TEST_NAME_UNIT%s: 单元测试%s\nTEST_CASE_AMOUNT_UNIT%s: ", flag, flag, flag);
  31. MNN_PRINT("{\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n", wrong, right);
  32. MNN_PRINT("TEST_CASE={\"name\":\"单元测试%s\",\"failed\":%d,\"passed\":%d}\n", flag, wrong, right);
  33. }
  34. int MNNTestSuite::run(const char* key, int precision, const char* flag) {
  35. if (key == NULL || strlen(key) == 0)
  36. return 0;
  37. std::vector<std::pair<std::string, float>> runTimes;
  38. auto suite = MNNTestSuite::get();
  39. std::string prefix = key;
  40. std::vector<std::string> wrongs;
  41. size_t runUnit = 0;
  42. for (int i = 0; i < suite->mTests.size(); ++i) {
  43. MNNTestCase* test = suite->mTests[i];
  44. if (test->name.find(prefix) == 0) {
  45. runUnit++;
  46. MNN_PRINT("\trunning %s.\n", test->name.c_str());
  47. MNN::Timer _t;
  48. auto res = test->run(precision);
  49. runTimes.emplace_back(std::make_pair(test->name, _t.durationInUs() / 1000.0f));
  50. if (!res) {
  51. wrongs.emplace_back(test->name);
  52. }
  53. }
  54. }
  55. std::sort(runTimes.begin(), runTimes.end(), [](const std::pair<std::string, float>& left, const std::pair<std::string, float>& right) {
  56. return left.second < right.second;
  57. });
  58. for (auto& iter : runTimes) {
  59. MNN_PRINT("%s cost time: %.3f ms\n", iter.first.c_str(), iter.second);
  60. }
  61. if (wrongs.empty()) {
  62. MNN_PRINT("√√√ all <%s> tests passed.\n", key);
  63. }
  64. for (auto& wrong : wrongs) {
  65. MNN_PRINT("Error: %s\n", wrong.c_str());
  66. }
  67. printTestResult(wrongs.size(), runUnit - wrongs.size(), flag);
  68. return wrongs.size();
  69. }
  70. int MNNTestSuite::runAll(int precision, const char* flag) {
  71. auto suite = MNNTestSuite::get();
  72. std::vector<std::string> wrongs;
  73. std::vector<std::pair<std::string, float>> runTimes;
  74. for (int i = 0; i < suite->mTests.size(); ++i) {
  75. MNNTestCase* test = suite->mTests[i];
  76. if (test->name.find("speed") != std::string::npos) {
  77. // Don't test for speed because cost
  78. continue;
  79. }
  80. if (test->name.find("model") != std::string::npos) {
  81. // Don't test for model because need resource
  82. continue;
  83. }
  84. MNN_PRINT("\trunning %s.\n", test->name.c_str());
  85. MNN::Timer _t;
  86. auto res = test->run(precision);
  87. runTimes.emplace_back(std::make_pair(test->name, _t.durationInUs() / 1000.0f));
  88. if (!res) {
  89. wrongs.emplace_back(test->name);
  90. }
  91. }
  92. std::sort(runTimes.begin(), runTimes.end(), [](const std::pair<std::string, float>& left, const std::pair<std::string, float>& right) {
  93. return left.second < right.second;
  94. });
  95. for (auto& iter : runTimes) {
  96. MNN_PRINT("%s cost time: %.3f ms\n", iter.first.c_str(), iter.second);
  97. }
  98. if (wrongs.empty()) {
  99. MNN_PRINT("√√√ all tests passed.\n");
  100. }
  101. for (auto& wrong : wrongs) {
  102. MNN_PRINT("Error: %s\n", wrong.c_str());
  103. }
  104. printTestResult(wrongs.size(), suite->mTests.size() - wrongs.size(), flag);
  105. return wrongs.size();
  106. }