TestUtils.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. //
  2. // TestUtils.h
  3. // MNN
  4. //
  5. // Created by MNN on 2019/01/15.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
  8. #ifndef TestUtils_h
  9. #define TestUtils_h
  10. #include <assert.h>
  11. #include <stdio.h>
  12. #include <functional>
  13. #include <string>
  14. #include <MNN/MNNForwardType.h>
  15. #include <MNN/Tensor.hpp>
  16. #include <math.h>
  17. #include <iostream>
  18. #include "core/Backend.hpp"
  19. #include <MNN/expr/Executor.hpp>
  20. #include <MNN/expr/ExecutorScope.hpp>
  21. #include "MNN_generated.h"
/**
* @brief dispatch payload on all available backends
* @param payload test to perform; it is invoked with an MNNForwardType
*        identifying the backend to run on.
*        NOTE(review): presumably called once per available backend — the
*        definition is not in this header, TODO confirm.
*/
void dispatch(std::function<void(MNNForwardType)> payload);
/**
* @brief dispatch payload on given backend
* @param payload test to perform; it is invoked with an MNNForwardType.
*        NOTE(review): presumably called with `backend` — the definition is
*        not in this header, TODO confirm.
* @param backend given backend
*/
void dispatch(std::function<void(MNNForwardType)> payload, MNNForwardType backend);
  33. /**
  34. @brief check the result with the ground truth
  35. @param result data
  36. @param rightData
  37. @param size
  38. @param threshold
  39. */
  40. template <typename T>
  41. bool checkVector(const T* result, const T* rightData, int size, T threshold){
  42. MNN_ASSERT(result != nullptr);
  43. MNN_ASSERT(rightData != nullptr);
  44. MNN_ASSERT(size >= 0);
  45. for(int i = 0; i < size; ++i){
  46. if(fabs(result[i] - rightData[i]) > threshold){
  47. std::cout << "No." << i << " error, right: " << rightData[i] << ", compute: " << result[i] << std::endl;
  48. return false;
  49. }
  50. }
  51. return true;
  52. }
  53. template <typename T>
  54. bool checkVectorByRelativeError(const T* result, const T* rightData, int size, float rtol) {
  55. MNN_ASSERT(result != nullptr);
  56. MNN_ASSERT(rightData != nullptr);
  57. MNN_ASSERT(size >= 0);
  58. float maxValue = 0.0f;
  59. for(int i = 0; i < size; ++i){
  60. maxValue = fmax(fabs(rightData[i]), maxValue);
  61. }
  62. float reltiveError = maxValue * rtol;
  63. for(int i = 0; i < size; ++i){
  64. if (fabs(result[i] - rightData[i]) > reltiveError) {
  65. std::cout << i << ": right: " << rightData[i] << ", compute: " << result[i] << std::endl;
  66. return false;
  67. }
  68. }
  69. return true;
  70. }
  71. template <typename T>
  72. bool checkVectorByRelativeError(const T* result, const T* rightData, const T* alterRightData, int size, float rtol) {
  73. MNN_ASSERT(result != nullptr);
  74. MNN_ASSERT(rightData != nullptr);
  75. MNN_ASSERT(size >= 0);
  76. float maxValue = 0.0f;
  77. for(int i = 0; i < size; ++i) {
  78. maxValue = fmax(fmax(fabs(rightData[i]), fabs(alterRightData[i])), maxValue);
  79. }
  80. float reltiveError = maxValue * rtol;
  81. for(int i = 0; i < size; ++i) {
  82. if (fabs(result[i] - rightData[i]) > reltiveError && fabs(result[i] - alterRightData[i]) > reltiveError) {
  83. std::cout << i << ": right: " << rightData[i] << " or " << alterRightData[i] << ", compute: " << result[i] << std::endl;
  84. return false;
  85. }
  86. }
  87. return true;
  88. }
// The three functions below are declared here; their definitions are not in this header.
// NOTE(review): getTestPrecision presumably maps backend / precision mode / fp16 support
// to an int precision level, which looks like it is meant to index FP32Converter
// below — TODO confirm against the definition.
int getTestPrecision(MNNForwardType forwardType, MNN::BackendConfig::PrecisionMode precision, bool isSupportFp16);
// NOTE(review): presumably rounds fp32Value to bfloat16 precision and returns it
// widened back to float — TODO confirm.
float convertFP32ToBF16(float fp32Value);
// NOTE(review): presumably rounds fp32Value to IEEE half precision and returns it
// widened back to float — TODO confirm.
float convertFP32ToFP16(float fp32Value);
  92. inline float keepFP32Precision(float fp32Value) {
  93. return fp32Value;
  94. }
// NOTE(review): presumably returns the MNNForwardType of the backend the current test
// is running on; the definition is not in this header — TODO confirm.
MNNForwardType getCurrentType();
// Function-pointer type matching keepFP32Precision, convertFP32ToBF16 and convertFP32ToFP16.
using ConvertFP32 = float(*)(float fp32Value);
// Table of FP32 converters. Element order matters, since entries are looked up by position:
//   [0], [1] -> keepFP32Precision (value returned unchanged)
//   [2]      -> convertFP32ToBF16 if MNN_SUPPORT_BF16 is defined, otherwise keepFP32Precision
//   [3]      -> convertFP32ToFP16
// NOTE(review): presumably indexed by the value returned from getTestPrecision — TODO confirm.
// Because it is `static` at namespace scope in a header, every translation unit that
// includes this file gets its own copy.
const static std::vector<ConvertFP32> FP32Converter = {
keepFP32Precision,
keepFP32Precision,
#ifdef MNN_SUPPORT_BF16
convertFP32ToBF16,
#else
keepFP32Precision,
#endif
convertFP32ToFP16
};
  107. #endif /* TestUtils_h */