Arm82Unary.cpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. //
  2. // Arm82Unary.cpp
  3. // MNN
  4. //
  5. // Created by MNN on 2018/08/02.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
  8. #if defined(__ANDROID__) || defined(__aarch64__)
  9. #include <vector>
  10. #include <cmath>
  11. #include <algorithm>
  12. #include "Arm82Unary.hpp"
  13. #include "Arm82Backend.hpp"
  14. #include "core/Macro.h"
  15. #include "core/OpCommonUtils.hpp"
  16. #include "core/Concurrency.h"
  17. #include "backend/cpu/UnaryUtils.hpp"
  18. #include "Arm82OptFunc.hpp"
  19. #include "MNN_generated.h"
  20. #include <arm_neon.h>
  21. namespace MNN {
  22. struct VecSquare {
  23. float16x8_t operator()(float16x8_t &x) const {
  24. return x * x;
  25. }
  26. };
  27. struct VecRsqrt {
  28. float16x8_t operator()(float16x8_t &x) const {
  29. return vrsqrteq_f16(x);
  30. }
  31. };
  32. struct VecNeg {
  33. float16x8_t operator()(float16x8_t &x) const {
  34. return vnegq_f16(x);
  35. }
  36. };
  37. struct VecAbs {
  38. float16x8_t operator()(float16x8_t &x) const {
  39. return vabsq_f16(x);
  40. }
  41. };
  42. struct VecRecipocal {
  43. float16x8_t operator()(float16x8_t &x) const {
  44. return vrecpeq_f16(x);
  45. }
  46. };
  47. #if defined(__aarch64__)
  48. struct VecSqrt {
  49. float16x8_t operator()(float16x8_t &x) const {
  50. return vabsq_f16(x);
  51. }
  52. };
  53. #endif
  54. template<typename Compute>
  55. void FP16VecUnary(void *dstRaw, const void *src0Raw, int elementSize) {
  56. Compute Func;
  57. auto dst = (float16_t*)dstRaw;
  58. auto src0 = (const float16_t*)src0Raw;
  59. const int sizeDivUnit = elementSize / 8;
  60. const int remainCount = elementSize - sizeDivUnit * 8;
  61. if (sizeDivUnit > 0) {
  62. for (int i = 0; i < sizeDivUnit; ++i) {
  63. float16x8_t a = vld1q_f16(src0);
  64. vst1q_f16(dst, Func(a));
  65. src0 += 8;
  66. dst += 8;
  67. }
  68. }
  69. if (remainCount > 0) {
  70. float16_t tempSrc0[8];
  71. float16_t tempDst[8];
  72. ::memcpy(tempSrc0, src0, remainCount * sizeof(int16_t));
  73. float16x8_t a = vld1q_f16(tempSrc0);
  74. vst1q_f16(tempDst, Func(a));
  75. ::memcpy(dst, tempDst, remainCount * sizeof(int16_t));
  76. }
  77. }
  78. #define BLOCK_SIZE 16
  79. template<typename Compute>
  80. static void _Wrap(void* outRaw, const void* inpRaw, int realSize) {
  81. Compute execute;
  82. float out[BLOCK_SIZE];
  83. float inp[BLOCK_SIZE];
  84. int b = realSize / BLOCK_SIZE;
  85. int remain = realSize % BLOCK_SIZE;
  86. auto outR = (int16_t*)outRaw;
  87. auto inpR = (const int16_t*)inpRaw;
  88. for (int i=0; i<b; ++i) {
  89. MNNDequantizeFP16(inpR, inp, BLOCK_SIZE);
  90. execute(out, inp, BLOCK_SIZE);
  91. MNNQuantizeFP16(out, outR, BLOCK_SIZE);
  92. outR += BLOCK_SIZE;
  93. inpR += BLOCK_SIZE;
  94. }
  95. if (remain > 0) {
  96. MNNDequantizeFP16(inpR, inp, remain);
  97. execute(out, inp, remain);
  98. MNNQuantizeFP16(out, outR, remain);
  99. }
  100. }
  101. struct _Exp {
  102. void operator()(void* outRaw, const void* inpRaw, int realSize) const {
  103. auto out = (float*)outRaw;
  104. auto inp = (const float*)inpRaw;
  105. float offset[2] = {
  106. 1.0f,
  107. 0.0f
  108. };
  109. MNNExp(out, inp, offset, realSize);
  110. }
  111. };
  112. struct _ExpM1 {
  113. void operator()(void* outRaw, const void* inpRaw, int realSize) const {
  114. auto out = (float*)outRaw;
  115. auto inp = (const float*)inpRaw;
  116. float offset[2] = {
  117. 1.0f,
  118. -1.0f
  119. };
  120. MNNExp(out, inp, offset, realSize);
  121. }
  122. };
  123. struct _Tanh {
  124. void operator()(void* outRaw, const void* inpRaw, int realSize) const {
  125. auto out = (float*)outRaw;
  126. auto inp = (const float*)inpRaw;
  127. MNNTanh(out, inp, realSize);
  128. }
  129. };
  130. struct _Sigmoid {
  131. void operator()(void* outRaw, const void* inpRaw, int realSize) const {
  132. auto out = (float*)outRaw;
  133. auto inp = (const float*)inpRaw;
  134. MNNSigmoidLowp(out, inp, realSize);
  135. }
  136. };
  137. void FP16HardSwish(void* outRaw, const void* inpRaw, int realSize) {
  138. auto out = (FLOAT16*)outRaw;
  139. auto inp = (const FLOAT16*)inpRaw;
  140. int sizeC8 = realSize / 8;
  141. int sizeRemain = realSize % 8;
  142. if (sizeC8 > 0) {
  143. float16x8_t zero = vdupq_n_f16(0.f);
  144. float16x8_t three = vdupq_n_f16(3.f);
  145. float16x8_t six = vdupq_n_f16(6.f);
  146. float16x8_t divsix = vdupq_n_f16(1.0f/6.f);
  147. for (int i = 0; i < sizeC8; i++) {
  148. auto x = vld1q_f16(inp);
  149. auto y = vmulq_f16(vmulq_f16(x, vminq_f16(vmaxq_f16(vaddq_f16(x, three), zero), six)), divsix);
  150. vst1q_f16(out, y);
  151. out += 8;
  152. inp += 8;
  153. }
  154. }
  155. for (int i=0; i<sizeRemain; ++i) {
  156. auto x = inp[i];
  157. float16_t y;
  158. if (x <= -3) {
  159. y = 0;
  160. } else if (x >= 3) {
  161. y = x;
  162. } else {
  163. y = x * (x + 3) / 6;
  164. }
  165. out[i] = y;
  166. }
  167. }
  168. template <typename Func, typename T>
  169. struct _Unary {
  170. void operator()(void* outputPtr, const void* inputPtr, int elementSize) const {
  171. Func f;
  172. const T *inputData = (T*)inputPtr;
  173. T *outputData = (T *)outputPtr;
  174. for (int i=0; i<elementSize; ++i) {
  175. outputData[i] = f(inputData[i]);
  176. }
  177. }
  178. };
// Selects the fp16 execution function for a unary op type. Two strategies:
//  - FP16VecUnary<Vec*>: ops with a direct NEON fp16 instruction run
//    natively 8 lanes at a time (ABS/SQUARE/NEG/RSQRT/RECIPROCAL, and
//    SQRT on aarch64). Note RSQRT and RECIPROCAL use hardware *estimate*
//    instructions, so their precision is reduced.
//  - _Wrap<...>: everything else dequantizes blocks to fp32, runs the
//    generic CPU unary kernels (UnaryUtils.hpp), and quantizes back.
// NOTE(review): the `precision` argument is currently ignored here.
// Returns nullptr (after MNN_ASSERT in debug builds) for unsupported types.
MNNUnaryExecute Arm82Unary::select(int type, int precision) {
    switch (type) {
        case UnaryOpOperation_ABS:
            return FP16VecUnary<VecAbs>;
        case UnaryOpOperation_SQUARE:
            return FP16VecUnary<VecSquare>;
        case UnaryOpOperation_NEG:
            return FP16VecUnary<VecNeg>;
        case UnaryOpOperation_RSQRT:
            return FP16VecUnary<VecRsqrt>;
        case UnaryOpOperation_EXP:
            return _Wrap<_Exp>;
        case UnaryOpOperation_COS:
            return _Wrap<_Unary<UnaryCos<float>, float>>;
        case UnaryOpOperation_SIN:
            return _Wrap<_Unary<UnarySin<float>, float>>;
        case UnaryOpOperation_SIGMOID:
            return _Wrap<_Sigmoid>;
        case UnaryOpOperation_TANH:
            return _Wrap<_Tanh>;
        case UnaryOpOperation_TAN:
            return _Wrap<_Unary<UnaryTan<float>, float>>;
        case UnaryOpOperation_ATAN:
            return _Wrap<_Unary<UnaryATan<float>, float>>;
#if defined(__aarch64__)
        // aarch64 builds have the fp16 vector sqrt instruction; elsewhere
        // fall back to the fp32 scalar kernel.
        case UnaryOpOperation_SQRT:
            return FP16VecUnary<VecSqrt>;
#else
        case UnaryOpOperation_SQRT:
            return _Wrap<_Unary<UnarySqrt<float>, float>>;
#endif
        case UnaryOpOperation_CEIL:
            return _Wrap<_Unary<UnaryCeil<float>, float>>;
        case UnaryOpOperation_RECIPROCAL:
            return FP16VecUnary<VecRecipocal>;
        case UnaryOpOperation_LOG1P:
            return _Wrap<_Unary<UnaryLog1p<float>, float>>;
        case UnaryOpOperation_LOG:
            return _Wrap<_Unary<UnaryLog<float>, float>>;
        case UnaryOpOperation_FLOOR:
            return _Wrap<_Unary<UnaryFloor<float>, float>>;
        case UnaryOpOperation_BNLL:
            return _Wrap<_Unary<UnaryBNLL<float>, float>>;
        case UnaryOpOperation_ACOSH:
            return _Wrap<_Unary<UnaryAcosh<float>, float>>;
        case UnaryOpOperation_SINH:
            return _Wrap<_Unary<UnarySinh<float>, float>>;
        case UnaryOpOperation_ASINH:
            return _Wrap<_Unary<UnaryAsinh<float>, float>>;
        case UnaryOpOperation_ATANH:
            return _Wrap<_Unary<UnaryAtanh<float>, float>>;
        case UnaryOpOperation_SIGN:
            return _Wrap<_Unary<UnarySign<float>, float>>;
        case UnaryOpOperation_ROUND:
            return _Wrap<_Unary<UnaryRound<float>, float>>;
        case UnaryOpOperation_COSH:
            return _Wrap<_Unary<UnaryCosh<float>, float>>;
        case UnaryOpOperation_ERF:
            return _Wrap<_Unary<UnaryErf<float>, float>>;
        case UnaryOpOperation_ERFC:
            return _Wrap<_Unary<UnaryErfc<float>, float>>;
        case UnaryOpOperation_ERFINV:
            return _Wrap<_Unary<UnaryErfinv<float>, float>>;
        case UnaryOpOperation_EXPM1:
            return _Wrap<_ExpM1>;
        case UnaryOpOperation_ASIN:
            return _Wrap<_Unary<UnaryAsin<float>, float>>;
        case UnaryOpOperation_ACOS:
            return _Wrap<_Unary<UnaryAcos<float>, float>>;
        case UnaryOpOperation_HARDSWISH:
            return FP16HardSwish;
        default:
            MNN_ASSERT(false);
            break;
    }
    return nullptr;
}
  256. } // namespace MNN
  257. #endif