pictureRecognition_module.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. //
  2. // pictureRecognition_module.cpp
  3. // MNN
  4. //
  5. // Created by MNN on 2018/05/14.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
  8. #include <stdio.h>
  9. #include <MNN/ImageProcess.hpp>
  10. #include <MNN/expr/Module.hpp>
  11. #include <MNN/expr/Executor.hpp>
  12. #include <MNN/expr/ExprCreator.hpp>
  13. #define MNN_OPEN_TIME_TRACE
  14. #include <algorithm>
  15. #include <fstream>
  16. #include <functional>
  17. #include <memory>
  18. #include <sstream>
  19. #include <vector>
  20. #include <MNN/AutoTime.hpp>
  21. #define STB_IMAGE_IMPLEMENTATION
  22. #include "stb_image.h"
  23. #include "stb_image_write.h"
  24. using namespace MNN::CV;
  25. int main(int argc, const char* argv[]) {
  26. if (argc < 3) {
  27. MNN_PRINT("Usage: ./pictureRecognition_module.out model.mnn input0.jpg input1.jpg input2.jpg ... \n");
  28. return 0;
  29. }
  30. // Load module with Config
  31. /*
  32. MNN::Express::Module::BackendInfo bnInfo;
  33. bnInfo.type = MNN_FORWARD_CPU;
  34. MNN::Express::Module::Config configs;
  35. configs.backend = &bnInfo;
  36. std::shared_ptr<MNN::Express::Module> net(MNN::Express::Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], &configs));
  37. */
  38. // Load module with Runtime
  39. std::vector<MNN::ScheduleConfig> sConfigs;
  40. MNN::ScheduleConfig sConfig;
  41. sConfig.type = MNN_FORWARD_AUTO;
  42. sConfigs.push_back(sConfig);
  43. std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr = std::shared_ptr<MNN::Express::Executor::RuntimeManager>(MNN::Express::Executor::RuntimeManager::createRuntimeManager(sConfigs));
  44. if(rtmgr == nullptr) {
  45. MNN_ERROR("Empty RuntimeManger\n");
  46. return 0;
  47. }
  48. // Give cache full path which must be Readable and writable
  49. rtmgr->setCache(".cachefile");
  50. std::shared_ptr<MNN::Express::Module> net(MNN::Express::Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
  51. // Create Input
  52. int batchSize = argc - 2;
  53. auto input = MNN::Express::_Input({batchSize, 3, 224, 224}, MNN::Express::NC4HW4);
  54. for (int batch = 0; batch < batchSize; ++batch) {
  55. int size_w = 224;
  56. int size_h = 224;
  57. int bpp = 3;
  58. auto inputPatch = argv[batch + 2];
  59. int width, height, channel;
  60. auto inputImage = stbi_load(inputPatch, &width, &height, &channel, 4);
  61. if (nullptr == inputImage) {
  62. MNN_ERROR("Can't open %s\n", inputPatch);
  63. return 0;
  64. }
  65. MNN_PRINT("origin size: %d, %d\n", width, height);
  66. Matrix trans;
  67. // Set transform, from dst scale to src, the ways below are both ok
  68. trans.setScale((float)(width-1) / (size_w-1), (float)(height-1) / (size_h-1));
  69. ImageProcess::Config config;
  70. config.filterType = BILINEAR;
  71. float mean[3] = {103.94f, 116.78f, 123.68f};
  72. float normals[3] = {0.017f, 0.017f, 0.017f};
  73. // float mean[3] = {127.5f, 127.5f, 127.5f};
  74. // float normals[3] = {0.00785f, 0.00785f, 0.00785f};
  75. ::memcpy(config.mean, mean, sizeof(mean));
  76. ::memcpy(config.normal, normals, sizeof(normals));
  77. config.sourceFormat = RGBA;
  78. config.destFormat = BGR;
  79. std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
  80. pretreat->setMatrix(trans);
  81. // for NC4HW4, UP_DIV(3, 4) * 4 = 4
  82. pretreat->convert((uint8_t*)inputImage, width, height, 0, input->writeMap<float>() + batch * 4 * 224 * 224, 224, 224, 4, 0, halide_type_of<float>());
  83. stbi_image_free(inputImage);
  84. }
  85. auto outputs = net->onForward({input});
  86. auto output = MNN::Express::_Convert(outputs[0], MNN::Express::NHWC);
  87. output = MNN::Express::_Reshape(output, {0, -1});
  88. int topK = 10;
  89. auto topKV = MNN::Express::_TopKV2(output, MNN::Express::_Scalar<int>(topK));
  90. auto value = topKV[0]->readMap<float>();
  91. auto indice = topKV[1]->readMap<int>();
  92. for (int batch = 0; batch < batchSize; ++batch) {
  93. MNN_PRINT("For Input: %s \n", argv[batch+2]);
  94. for (int i=0; i<topK; ++i) {
  95. MNN_PRINT("%d, %f\n", indice[batch * topK + i], value[batch * topK + i]);
  96. }
  97. }
  98. rtmgr->updateCache();
  99. return 0;
  100. }