  1. //
  2. // VecHalf.hpp
  3. // MNN
  4. //
  5. // Created by MNN on 2021/01/26.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
#ifndef VecHalf_hpp
#define VecHalf_hpp
#include "core/Macro.h"
#include <stdint.h>
#include <cstring>   // memcpy: strict-aliasing-safe type punning in scalar VecHalf
#include <array>
#include <algorithm> // supply std::max and std::min
// Hoist the intrinsic headers to file scope. Later in this file they are
// #included inside namespace MNN::Math, which is only harmless when their
// include guards are already set; including them here (outside any namespace)
// guarantees that.
#if defined(MNN_USE_SSE)
#if defined(_MSC_VER)
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#endif
#if defined(MNN_USE_NEON)
#include <arm_neon.h>
#endif
namespace MNN {
namespace Math {
  16. template <int N>
  17. struct VecHalf {
  18. using VecType = VecHalf<N>;
  19. std::array<float, N> value;
  20. VecType operator+(const VecType& lr) const {
  21. VecType dst;
  22. for (int i = 0; i < N; ++i) {
  23. dst.value[i] = value[i] + lr.value[i];
  24. }
  25. return dst;
  26. }
  27. VecType operator-(const VecType& lr) const {
  28. VecType dst;
  29. for (int i = 0; i < N; ++i) {
  30. dst.value[i] = value[i] - lr.value[i];
  31. }
  32. return dst;
  33. }
  34. VecType operator*(const VecType& lr) const {
  35. VecType dst;
  36. for (int i = 0; i < N; ++i) {
  37. dst.value[i] = value[i] * lr.value[i];
  38. }
  39. return dst;
  40. }
  41. VecType operator*(float lr) const {
  42. VecType dst;
  43. for (int i = 0; i < N; ++i) {
  44. dst.value[i] = value[i] * lr;
  45. }
  46. return dst;
  47. }
  48. VecType& operator=(const VecType& lr) {
  49. for (int i = 0; i < N; ++i) {
  50. value[i] = lr.value[i];
  51. }
  52. return *this;
  53. }
  54. VecType operator-() {
  55. VecType dst;
  56. for (int i = 0; i < N; ++i) {
  57. dst.value[i] = -value[i];
  58. }
  59. return dst;
  60. }
  61. VecHalf() {
  62. }
  63. VecHalf(const float v) {
  64. for (int i = 0; i < N; ++i) {
  65. value[i] = v;
  66. }
  67. }
  68. VecHalf(std::array<float, N>&& v) {
  69. value = std::move(v);
  70. }
  71. VecHalf(const VecType& lr) {
  72. for (int i = 0; i < N; ++i) {
  73. value[i] = lr.value[i];
  74. }
  75. }
  76. float operator[](size_t i) {
  77. return value[i];
  78. }
  79. static VecType broadcast(int16_t val) {
  80. VecType v;
  81. auto tempV = (int32_t*)v.value.data();
  82. for (int i = 0; i < N; ++i) {
  83. tempV[i] = val << 16;
  84. }
  85. return v;
  86. }
  87. static VecType load(const int16_t* addr) {
  88. VecType v;
  89. auto tempV = (int32_t*)v.value.data();
  90. for (int i = 0; i < N; ++i) {
  91. tempV[i] = addr[i] << 16;
  92. }
  93. return v;
  94. }
  95. static void save(int16_t* addr, const VecType& v) {
  96. auto tempV = (int32_t*)v.value.data();
  97. for (int i = 0; i < N; ++i) {
  98. addr[i] = tempV[i] >> 16;
  99. }
  100. }
  101. static VecType max(const VecType& v1, const VecType& v2) {
  102. VecType dst;
  103. for (int i = 0; i < N; ++i) {
  104. dst.value[i] = std::max(v1.value[i], v2.value[i]);
  105. }
  106. return dst;
  107. }
  108. static VecType min(const VecType& v1, const VecType& v2) {
  109. VecType dst;
  110. for (int i = 0; i < N; ++i) {
  111. dst.value[i] = std::min(v1.value[i], v2.value[i]);
  112. }
  113. return dst;
  114. }
  115. static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
  116. VecType source[4] = {vec0, vec1, vec2, vec3};
  117. for (int i = 0; i < N; ++i) {
  118. vec0.value[i] = source[i % 4].value[i >> 2];
  119. vec1.value[i] = source[i % 4].value[(i + N)>> 2];
  120. vec2.value[i] = source[i % 4].value[(i + 2 * N)>> 2];
  121. vec3.value[i] = source[i % 4].value[(i + 3 * N)>> 2];
  122. }
  123. }
  124. static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
  125. MNN_ASSERT(false);
  126. }
  127. };
#if defined(MNN_USE_SSE)
#if defined(_MSC_VER)
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
// SSE specialization: 4 bf16-backed lanes held in a single __m128 of float32.
// With MNN_SSE_USE_FP16_INSTEAD defined, the 16-bit external format is IEEE
// fp16 (converted via F16C) instead of bf16.
// NOTE(review): the #includes above are expanded inside namespace MNN::Math;
// that is only harmless if the intrinsic headers were already included at
// global scope earlier (their include guards then suppress re-expansion) —
// TODO confirm.
template<>
struct VecHalf<4> {
    using VecType = VecHalf<4>;
    __m128 value;
    // Lane-wise arithmetic in float32.
    VecType operator+(const VecType& lr) const {
        VecType dst = { _mm_add_ps(value, lr.value) };
        return dst;
    }
    VecType operator-(const VecType& lr) const {
        VecType dst = { _mm_sub_ps(value, lr.value) };
        return dst;
    }
    VecType operator*(const VecType& lr) const {
        VecType dst = { _mm_mul_ps(value, lr.value) };
        return dst;
    }
    // Multiply every lane by a scalar (splat then multiply).
    VecType operator*(float lr) const {
        VecType dst = { _mm_mul_ps(value, _mm_set1_ps(lr)) };
        return dst;
    }
    VecType& operator=(const VecType& lr) {
        value = lr.value;
        return *this;
    }
    // Lane-wise negation; implemented as sign-bit flip on MSVC.
    VecType operator-() {
        VecType dst;
#if defined(_MSC_VER)
        dst.value = _mm_xor_ps(value, _mm_set1_ps(-0.f)); // Using unary operation to SSE vec is GCC extension. We can not do this directly in MSVC.
#else
        dst.value = -value;
#endif
        return dst;
    }
    // Default: lanes left uninitialized on purpose (hot-path temporary).
    VecHalf() {
    }
    // Splat a single float into all 4 lanes.
    VecHalf(const float v) {
        value = _mm_set1_ps(v);
    }
    VecHalf(__m128& v) {
        value = v;
    }
    VecHalf(__m128&& v) {
        value = std::move(v);
    }
    VecHalf(const VecType& lr) {
        value = lr.value;
    }
    VecHalf(VecType&& lr) {
        value = std::move(lr.value);
    }
    // Read one lane by index (no bounds check).
    float operator[](size_t i) {
#if defined(_MSC_VER) // X64 native only mandatory support SSE and SSE2 extension, and we can not find intrinsic function to extract element directly by index in SSE and SSE2 extension.
        float temp[4];
        _mm_storeu_ps(temp, value);
        return temp[i];
#else
        return value[i];
#endif
    }
    // Splat one 16-bit value into all lanes, widening to float32.
    static VecType broadcast(int16_t val) {
        auto temp = _mm_set1_epi16(val);
#ifndef MNN_SSE_USE_FP16_INSTEAD
        // bf16 path: interleave a zero word below each 16-bit value, i.e.
        // (bits << 16) per 32-bit lane, then reinterpret as float32.
        auto zero = _mm_xor_si128(temp, temp);
        auto res = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, temp));
#else
        // fp16 path: F16C half -> single conversion.
        auto res = _mm_cvtph_ps(temp);
#endif
        VecType v = { std::move(res) };
        return v;
    }
    // Load 4 x 16-bit values (64 bits) from addr and widen to float32.
    static VecType load(const int16_t* addr) {
        auto temp = _mm_loadl_epi64((__m128i*)addr);
#ifndef MNN_SSE_USE_FP16_INSTEAD
        // Same widening trick as broadcast(): zeros become the low mantissa bits.
        auto zero = _mm_xor_si128(temp, temp);
        auto res = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, temp));
#else
        auto res = _mm_cvtph_ps(temp);
#endif
        VecType v = { std::move(res) };
        return v;
    }
    // Store 4 lanes to addr as 16-bit values.
    static void save(int16_t* addr, const VecType& v) {
#ifndef MNN_SSE_USE_FP16_INSTEAD
        // bf16 path: arithmetic >> 16 keeps the high 16 bits of each float's
        // pattern (sign-extended into [-32768, 32767]), so the signed pack
        // below is exact — this is a truncating float32 -> bf16 conversion.
        auto temp = _mm_castps_si128(v.value);
        temp = _mm_srai_epi32(temp, 16);
        temp = _mm_packs_epi32(temp, temp);
#else
        // fp16 path: clamp before conversion. NOTE(review): function-local
        // statics carry a thread-safe init guard on every call; 0x8 appears
        // to be _MM_FROUND_NO_EXC (round using MXCSR) — TODO confirm.
        static __m128 gMinValue = _mm_set1_ps(-32768);
        static __m128 gMaxValue = _mm_set1_ps(32767);
        auto t = _mm_max_ps(v.value, gMinValue);
        t = _mm_min_ps(t, gMaxValue);
        auto temp = _mm_cvtps_ph(t, 0x8);
#endif
        _mm_storel_epi64((__m128i*)addr, temp);
    }
    // Lane-wise max/min.
    static VecType max(const VecType& v1, const VecType& v2) {
        VecType dst = { _mm_max_ps(v1.value, v2.value) };
        return dst;
    }
    static VecType min(const VecType& v1, const VecType& v2) {
        VecType dst = { _mm_min_ps(v1.value, v2.value) };
        return dst;
    }
    // 4x4 in-place transpose — same unpack/movelh/movehl shuffle sequence as
    // the standard _MM_TRANSPOSE4_PS macro.
    static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
        __m128 tmp3, tmp2, tmp1, tmp0;
        tmp0 = _mm_unpacklo_ps((vec0.value), (vec1.value));
        tmp2 = _mm_unpacklo_ps((vec2.value), (vec3.value));
        tmp1 = _mm_unpackhi_ps((vec0.value), (vec1.value));
        tmp3 = _mm_unpackhi_ps((vec2.value), (vec3.value));
        vec0.value = _mm_movelh_ps(tmp0, tmp2);
        vec1.value = _mm_movehl_ps(tmp2, tmp0);
        vec2.value = _mm_movelh_ps(tmp1, tmp3);
        vec3.value = _mm_movehl_ps(tmp3, tmp1);
    }
    // x86 VecHalf transpose12 unused in any case
    static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
        MNN_ASSERT(false);
    }
};
#endif
#if defined(MNN_USE_NEON)
#include <arm_neon.h>
// NEON specialization: 4 bf16-backed lanes held in one float32x4_t.
// NOTE(review): <arm_neon.h> is expanded inside namespace MNN::Math here;
// only harmless if it was already included at global scope — TODO confirm.
template<>
struct VecHalf<4> {
    using VecType = VecHalf<4>;
    float32x4_t value;
    // Lane-wise arithmetic in float32.
    VecType operator+(const VecType& lr) const {
        VecType dst = { vaddq_f32(value, lr.value) };
        return dst;
    }
    VecType operator-(const VecType& lr) const {
        VecType dst = { vsubq_f32(value, lr.value) };
        return dst;
    }
    VecType operator*(const VecType& lr) const {
        VecType dst = { vmulq_f32(value, lr.value) };
        return dst;
    }
    // Multiply every lane by a scalar (splat then multiply).
    VecType operator*(const float lr) const {
        VecType dst = { vmulq_f32(value, vdupq_n_f32(lr)) };
        return dst;
    }
    VecType& operator=(const VecType& lr) {
        value = lr.value;
        return *this;
    }
    // Lane-wise negation.
    VecType operator-() {
        VecType dst = { vnegq_f32(value) };
        return dst;
    }
    // Default: lanes left uninitialized on purpose (hot-path temporary).
    VecHalf() {
    }
    // Splat a single float into all 4 lanes.
    VecHalf(const float v) {
        value = vdupq_n_f32(v);
    }
    VecHalf(float32x4_t& v) {
        value = v;
    }
    VecHalf(float32x4_t&& v) {
        value = std::move(v);
    }
    VecHalf(const VecType& lr) {
        value = lr.value;
    }
    VecHalf(VecType&& lr) {
        value = std::move(lr.value);
    }
    // Read one lane by runtime index (GCC/Clang vector subscript extension).
    float operator[](const int i) {
        // vgetq_lane_f32(value, i) does NOT work, i must be const number such as 0, 2,
        return value[i];
    }
    // Splat one bf16 value into all lanes: widen to 32-bit with the bf16 bits
    // in the high half (<< 16), then reinterpret as float32.
    static VecType broadcast(int16_t val) {
        VecType dst = { vreinterpretq_f32_s32(vshll_n_s16(vdup_n_s16(val), 16)) };
        return dst;
    }
    // Load 4 bf16 values from addr and widen each to float32.
    static VecType load(const int16_t* addr) {
        // equivalent to this:
        // int16x4_t vec4s16 = vld1_s16(addr);  // load bf16 data as fixed point data of 16-bit.
        // int32x4_t vec4s32 = vshll_n_s16(vec4s16, 16);  // shift left 16bit as 32-bit data.
        // float32x4_t vec4f32 = vreinterpretq_f32_s32(vec4s32);  // treat 32-bit fix point result as float32 data
        // VecType dest = { vec4f32 };  // construct a struct of VecType
        VecType dst = { vreinterpretq_f32_s32(vshll_n_s16(vld1_s16(addr), 16)) };
        return dst;
    }
    // Store 4 lanes to addr as bf16: shift-right-narrow keeps the high 16 bits
    // of each float's bit pattern (truncating conversion, no rounding).
    static void save(int16_t* addr, const VecType& v) {
        vst1_s16(addr, vshrn_n_s32(vreinterpretq_s32_f32(v.value), 16));
        return;
    }
    // Lane-wise max/min.
    static VecType max(const VecType& v1, const VecType& v2) {
        VecType dst = { vmaxq_f32(v1.value, v2.value) };
        return dst;
    }
    static VecType min(const VecType& v1, const VecType& v2) {
        VecType dst = { vminq_f32(v1.value, v2.value) };
        return dst;
    }
    // 4x4 in-place transpose of four float32x4_t rows.
    // NOTE(review): float32x4_t values are passed straight into s32/s64
    // intrinsics with no vreinterpretq — this relies on the compiler's lax
    // vector conversions (e.g. -flax-vector-conversions / Clang leniency);
    // TODO confirm the build flags guarantee this compiles strictly.
    static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) {
#ifdef __aarch64__
        // Transpose 2x2 blocks of 32-bit lanes, then 2x2 blocks of 64-bit pairs.
        auto m0 = vtrn1q_s32(vec0.value, vec1.value);
        auto m1 = vtrn2q_s32(vec0.value, vec1.value);
        auto m2 = vtrn1q_s32(vec2.value, vec3.value);
        auto m3 = vtrn2q_s32(vec2.value, vec3.value);
        vec0.value = vtrn1q_s64(m0, m2);
        vec1.value = vtrn1q_s64(m1, m3);
        vec2.value = vtrn2q_s64(m0, m2);
        vec3.value = vtrn2q_s64(m1, m3);
#else
        // arm32: vtrn the 32-bit lanes, then swap the cross 64-bit halves.
        auto m0m1 = vtrnq_s32(vec0.value, vec1.value);
        auto m2m3 = vtrnq_s32(vec2.value, vec3.value);
        vec0.value = m0m1.val[0];
        vec1.value = m0m1.val[1];
        vec2.value = m2m3.val[0];
        vec3.value = m2m3.val[1];
        vec0.value = vsetq_lane_s64(vgetq_lane_s64(m2m3.val[0], 0), vec0.value, 1);
        vec1.value = vsetq_lane_s64(vgetq_lane_s64(m2m3.val[1], 0), vec1.value, 1);
        vec2.value = vsetq_lane_s64(vgetq_lane_s64(m0m1.val[0], 1), vec2.value, 0);
        vec3.value = vsetq_lane_s64(vgetq_lane_s64(m0m1.val[1], 1), vec3.value, 0);
        /*
        generated arm32 assembly code is almost the same as:
            vtrn.32 d0, d2
            vtrn.32 d1, d3
            vtrn.32 d4, d6
            vtrn.32 d5, d7
            vswp d1, d4
            vswp d3, d6
        */
#endif
    }
    // 4x4 transpose of 16-bit lanes: trn the 16-bit pairs, then the 32-bit pairs.
    static inline void transpose4(int16x4_t& vec0, int16x4_t& vec1, int16x4_t& vec2, int16x4_t& vec3) {
        auto trans0 = vtrn_s16(vec0, vec1);
        auto m0 = trans0.val[0];
        auto m1 = trans0.val[1];
        auto trans1 = vtrn_s16(vec2, vec3);
        auto m2 = trans1.val[0];
        auto m3 = trans1.val[1];
        auto trans2 = vtrn_s32(m0, m2);
        vec0 = trans2.val[0];
        vec2 = trans2.val[1];
        auto trans3 = vtrn_s32(m1, m3);
        vec1 = trans3.val[0];
        vec3 = trans3.val[1];
    }
    // In-place transpose of a 12-row x 4-column int16 tile stored with row
    // stride packCUnit: the 12 rows are handled as three 4x4 sub-transposes;
    // the scattered load/store order realizes the 12x4 -> 4x12 permutation.
    static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) {
        auto s0 = vld1_s16(srcPtr + 0 * packCUnit);
        auto s3 = vld1_s16(srcPtr + 1 * packCUnit);
        auto s6 = vld1_s16(srcPtr + 2 * packCUnit);
        auto s9 = vld1_s16(srcPtr + 3 * packCUnit);
        auto s1 = vld1_s16(srcPtr + 4 * packCUnit);
        auto s4 = vld1_s16(srcPtr + 5 * packCUnit);
        auto s7 = vld1_s16(srcPtr + 6 * packCUnit);
        auto s10 = vld1_s16(srcPtr + 7 * packCUnit);
        auto s2 = vld1_s16(srcPtr + 8 * packCUnit);
        auto s5 = vld1_s16(srcPtr + 9 * packCUnit);
        auto s8 = vld1_s16(srcPtr + 10 * packCUnit);
        auto s11 = vld1_s16(srcPtr + 11 * packCUnit);
        transpose4(s0, s3, s6, s9);
        transpose4(s1, s4, s7, s10);
        transpose4(s2, s5, s8, s11);
        vst1_s16(srcPtr + 0 * packCUnit, s0);
        vst1_s16(srcPtr + 1 * packCUnit, s1);
        vst1_s16(srcPtr + 2 * packCUnit, s2);
        vst1_s16(srcPtr + 3 * packCUnit, s3);
        vst1_s16(srcPtr + 4 * packCUnit, s4);
        vst1_s16(srcPtr + 5 * packCUnit, s5);
        vst1_s16(srcPtr + 6 * packCUnit, s6);
        vst1_s16(srcPtr + 7 * packCUnit, s7);
        vst1_s16(srcPtr + 8 * packCUnit, s8);
        vst1_s16(srcPtr + 9 * packCUnit, s9);
        vst1_s16(srcPtr + 10 * packCUnit, s10);
        vst1_s16(srcPtr + 11 * packCUnit, s11);
    }
};
#endif
} // namespace Math
} // namespace MNN
#endif // VecHalf_hpp