//
//  GemmSparse.cpp
//  MNN
//
//  Created by MNN on 2021/07/28.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include "GemmCommon.hpp"
#include "FunctionSummary.hpp"
#include "Vec8.hpp"
#include "core/Macro.h"

#ifdef MNN_X86_USE_ASM
extern "C" {
void _AVX_MNNPackedSparseMatMulEpx4EFMA_ASM(SparseMatMulParas* temp, const float* bias, const size_t* parameter, const float* postParameters);
void _AVX_MNNPackedSparseMatMulEpx1EFMA_ASM(SparseMatMulParas* temp, const float* bias, const size_t* parameter, const float* postParameters);
}
#endif

void _AVX_MNNGetSparseMatMulPackMode(int* eP, int* lP, int* hP) {
    *eP = 24;
    *lP = 1;
    *hP = 4;
    // hP is the sparse block size along the column dimension of the right-hand matrix;
    // for random (unstructured) sparsity it would be 1.
    return;
}
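// For example: with eP = 24, a GEMM with e = 50 is split into two full 24-wide tiles plus a
// 2-wide remainder. Each tile of mat_a is packed as [l, eP]; the remainder takes the
// "remained eSize" paths of the kernels below (callers pass one tile at a time, per the
// eSize <= eP note in the kernel comments).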
// Emulate a fused multiply-add, dst += src0 * src1, with a separate multiply and add.
#define EMULATED_AVX2_FMA(dst, src0, src1) dst = _mm256_add_ps(dst, _mm256_mul_ps(src0, src1));

// Clamp an accumulator vector to [minVec, maxVec], the post-process min/max bounds.
#define MIN_MAX_VEC(cVec) \
    cVec = _mm256_max_ps(cVec, minVec); \
    cVec = _mm256_min_ps(cVec, maxVec);
#define ONE_H_STORE_E24(cTilePtr) \
    cTilePtr[8 * 0] = c0VecPtr[0]; \
    cTilePtr[8 * 1] = c0VecPtr[1]; \
    cTilePtr[8 * 2] = c0VecPtr[2]; \
    cTilePtr[8 * 3] = c0VecPtr[3]; \
    cTilePtr[8 * 4] = c0VecPtr[4]; \
    cTilePtr[8 * 5] = c0VecPtr[5]; \
    cTilePtr[8 * 6] = c0VecPtr[6]; \
    cTilePtr[8 * 7] = c0VecPtr[7]; \
    \
    cTilePtr[8 * 8]  = c1VecPtr[0]; \
    cTilePtr[8 * 9]  = c1VecPtr[1]; \
    cTilePtr[8 * 10] = c1VecPtr[2]; \
    cTilePtr[8 * 11] = c1VecPtr[3]; \
    cTilePtr[8 * 12] = c1VecPtr[4]; \
    cTilePtr[8 * 13] = c1VecPtr[5]; \
    cTilePtr[8 * 14] = c1VecPtr[6]; \
    cTilePtr[8 * 15] = c1VecPtr[7]; \
    \
    cTilePtr[8 * 16] = c2VecPtr[0]; \
    cTilePtr[8 * 17] = c2VecPtr[1]; \
    cTilePtr[8 * 18] = c2VecPtr[2]; \
    cTilePtr[8 * 19] = c2VecPtr[3]; \
    cTilePtr[8 * 20] = c2VecPtr[4]; \
    cTilePtr[8 * 21] = c2VecPtr[5]; \
    cTilePtr[8 * 22] = c2VecPtr[6]; \
    cTilePtr[8 * 23] = c2VecPtr[7];
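// ONE_H_STORE_E24 scatters the 24 accumulated results of a single output row into the
// [h/unit, e, unit] layout of mat_c: consecutive e positions are `unit` (8) floats apart,
// so result i goes to cTilePtr[8 * i]. For example, the value for e index 5 lands at
// cTilePtr + 40, i.e. e row 5 of the current unit-wide column block.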
#define TRANSPOSE_4x4_WITH_STORE(rowIdx, offset, cVec0, cVec1, cVec2, cVec3, cTilePtr) \
    { \
        transposeTemp0 = _mm256_extractf128_ps(cVec0, offset); \
        transposeTemp1 = _mm256_extractf128_ps(cVec1, offset); \
        transposeTemp2 = _mm256_extractf128_ps(cVec2, offset); \
        transposeTemp3 = _mm256_extractf128_ps(cVec3, offset); \
        _MM_TRANSPOSE4_PS(transposeTemp0, transposeTemp1, transposeTemp2, transposeTemp3); \
        _mm_store_ps(cTilePtr + (rowIdx + 0) * unit, transposeTemp0); \
        _mm_store_ps(cTilePtr + (rowIdx + 1) * unit, transposeTemp1); \
        _mm_store_ps(cTilePtr + (rowIdx + 2) * unit, transposeTemp2); \
        _mm_store_ps(cTilePtr + (rowIdx + 3) * unit, transposeTemp3); \
    }

#define TRANSPOSE_4x24_WITH_STORE(cTilePtr, unit) \
    { \
        __m128 transposeTemp0; \
        __m128 transposeTemp1; \
        __m128 transposeTemp2; \
        __m128 transposeTemp3; \
        TRANSPOSE_4x4_WITH_STORE(0, 0, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr); \
        TRANSPOSE_4x4_WITH_STORE(4, 1, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr); \
        TRANSPOSE_4x4_WITH_STORE(8, 0, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr); \
        TRANSPOSE_4x4_WITH_STORE(12, 1, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr); \
        TRANSPOSE_4x4_WITH_STORE(16, 0, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr); \
        TRANSPOSE_4x4_WITH_STORE(20, 1, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr); \
    }
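// TRANSPOSE_4x24_WITH_STORE writes a 4-output-row x 24-e block back to mat_c. The twelve
// accumulators hold [4 h rows][24 e columns]; each TRANSPOSE_4x4_WITH_STORE call takes one
// 128-bit lane from four of them, transposes the 4x4 block, and stores it so the four output
// rows land as four consecutive floats at each e position inside the 8-wide unit of mat_c.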
#define REMAIN_TRANSPOSE_4x24_WITH_STORE(cTilePtr, unit) \
    { \
        __m128 transposeTemp0; \
        __m128 transposeTemp1; \
        __m128 transposeTemp2; \
        __m128 transposeTemp3; \
        int tailE = eSize % 4; \
        int eFull4 = eSize / 4; \
        switch (eFull4) { \
            case 5: \
                TRANSPOSE_4x4_WITH_STORE(16, 0, c2Vec, c5Vec, c8Vec, c11Vec, cTilePtr); \
            case 4: \
                TRANSPOSE_4x4_WITH_STORE(12, 1, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr); \
            case 3: \
                TRANSPOSE_4x4_WITH_STORE(8, 0, c1Vec, c4Vec, c7Vec, c10Vec, cTilePtr); \
            case 2: \
                TRANSPOSE_4x4_WITH_STORE(4, 1, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr); \
            case 1: \
                TRANSPOSE_4x4_WITH_STORE(0, 0, c0Vec, c3Vec, c6Vec, c9Vec, cTilePtr); \
            default: \
                break; \
        } \
        if (tailE) { \
            if (eFull4 == 5) { \
                transposeTemp0 = _mm256_extractf128_ps(c2Vec, 1); \
                transposeTemp1 = _mm256_extractf128_ps(c5Vec, 1); \
                transposeTemp2 = _mm256_extractf128_ps(c8Vec, 1); \
                transposeTemp3 = _mm256_extractf128_ps(c11Vec, 1); \
            } else if (eFull4 == 4) { \
                transposeTemp0 = _mm256_extractf128_ps(c2Vec, 0); \
                transposeTemp1 = _mm256_extractf128_ps(c5Vec, 0); \
                transposeTemp2 = _mm256_extractf128_ps(c8Vec, 0); \
                transposeTemp3 = _mm256_extractf128_ps(c11Vec, 0); \
            } else if (eFull4 == 3) { \
                transposeTemp0 = _mm256_extractf128_ps(c1Vec, 1); \
                transposeTemp1 = _mm256_extractf128_ps(c4Vec, 1); \
                transposeTemp2 = _mm256_extractf128_ps(c7Vec, 1); \
                transposeTemp3 = _mm256_extractf128_ps(c10Vec, 1); \
            } else if (eFull4 == 2) { \
                transposeTemp0 = _mm256_extractf128_ps(c1Vec, 0); \
                transposeTemp1 = _mm256_extractf128_ps(c4Vec, 0); \
                transposeTemp2 = _mm256_extractf128_ps(c7Vec, 0); \
                transposeTemp3 = _mm256_extractf128_ps(c10Vec, 0); \
            } else if (eFull4 == 1) { \
                transposeTemp0 = _mm256_extractf128_ps(c0Vec, 1); \
                transposeTemp1 = _mm256_extractf128_ps(c3Vec, 1); \
                transposeTemp2 = _mm256_extractf128_ps(c6Vec, 1); \
                transposeTemp3 = _mm256_extractf128_ps(c9Vec, 1); \
            } else { \
                transposeTemp0 = _mm256_extractf128_ps(c0Vec, 0); \
                transposeTemp1 = _mm256_extractf128_ps(c3Vec, 0); \
                transposeTemp2 = _mm256_extractf128_ps(c6Vec, 0); \
                transposeTemp3 = _mm256_extractf128_ps(c9Vec, 0); \
            } \
            _MM_TRANSPOSE4_PS(transposeTemp0, transposeTemp1, transposeTemp2, transposeTemp3); \
            int offset = 4 * eFull4; \
            switch (tailE) { \
                case 3: \
                    _mm_storeu_ps(cTilePtr + (offset + 2) * unit, transposeTemp2); \
                case 2: \
                    _mm_storeu_ps(cTilePtr + (offset + 1) * unit, transposeTemp1); \
                case 1: \
                    _mm_storeu_ps(cTilePtr + (offset + 0) * unit, transposeTemp0); \
                default: \
                    break; \
            } \
        } \
    }
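// REMAIN_TRANSPOSE_4x24_WITH_STORE handles a partial e tile: the first switch falls through to
// store every full group of four e positions, then the tailE branch transposes the one remaining
// half-lane and stores only the valid rows. For example, eSize = 14 gives eFull4 = 3 (e 0..11
// stored by the fall-through cases) and tailE = 2, taken from the upper 128-bit lane of
// c1Vec/c4Vec/c7Vec/c10Vec (e 12..13).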
#define FP32_BYTES 4
#define AVX2_SPARSE_EP 24
#define AVX2_SP_BLOCK4 4
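// Sparse encoding as consumed by the kernels below: B holds only the non-zero weights in the
// order they are visited; NNZMap holds one non-zero count per output row (or per 4-row block in
// the blocked path of the Epx4 kernel); dataOffsetMap holds, for each non-zero, the signed
// float-element delta that moves the packed-A pointer from the previous non-zero's l row to the
// current one. For instance, with eP = 24 and consecutive non-zeros on l rows 1 and 4, the
// recorded delta would be (4 - 1) * 24 = 72 floats.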
void _AVX_MNNPackedSparseMatMulEpx1EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
                                        const float* postParameters, const float* bias, unsigned int* NNZMap,
                                        int* dataOffsetMap) {
    /*
        mat_a: [eSize/eP, l, eP]
        mat_c: [h/unit, e, unit]
        bias: [h, ]
        parameter[0]: eP * bytes
        parameter[1]: l
        parameter[2]: h
        parameter[3]: h/unit stride, equal to e * unit * sizeof(dataType)
        parameter[4]: unit
        eSize: this tile's real e size, which may differ from eP!
        postParameters[2]: min_val of output
        postParameters[3]: max_val of output
    */
    /*
        This function performs sparse matmul with bias add and a min/max post process.
        The basic process of the dense version is:
            batch_matmul([l, eP], [h/hP, l, hP]) --> [h/hP, eP, hP].
        When mat_b is sparse encoded, the function changes accordingly.
        The whole process is divided into two parts: the full-hP part and the remain part.
        In the full-hP part, each iteration processes hP columns (rows, actually) of mat_b,
        and the non-zero values are encoded hP at a time, contiguously.
        In the remain part, each iteration processes one column (row, actually) of mat_b,
        and the non-zero values are encoded one by one.
        (This function is specialized for hP = 1, so only the one-by-one path is used.)
        ***********************************************
        Specialization description:
        1. eP = 24, hP = 1, lP = 1;
        2. mat_a is stored in [eSize/eP, l, eP] format;
        3. mat_c is stored in [h/unit, e, unit] format;
        4. the data type is fixed to float32, so bytes = 4;
        5. unit is fixed to 8.
        ***********************************************
        Note that the function keeps the aStride logic for a mat_a that contains more than one
        l * eP tile, but for now eSize is limited to eSize <= eP!
    */
#ifdef MNN_X86_USE_ASM
    if (eSize == AVX2_SPARSE_EP && parameter[2] % 4 == 0) {
        // Use the asm kernel when eSize == eP (24) and h is a multiple of 4.
        SparseMatMulParas temp = {C, A, B, NNZMap, dataOffsetMap};
        SparseMatMulParas* tempPtr = &temp;
        _AVX_MNNPackedSparseMatMulEpx1EFMA_ASM(tempPtr, bias, parameter, postParameters);
        return;
    }
#endif
    const size_t aStride = parameter[0] / FP32_BYTES;
    const size_t l = parameter[1];
    const size_t h = parameter[2];
    const size_t cStride = parameter[3] / FP32_BYTES; // the intrinsic path works in element strides, not bytes.
    const size_t unit = 8;
    MNN_ASSERT(eSize <= aStride);
    auto minVec = _mm256_broadcast_ss(postParameters + 2);
    auto maxVec = _mm256_broadcast_ss(postParameters + 3);
    // full [l, eP] X [h/unit, e, unit]
    for (int matALoopIdx = 0; matALoopIdx < eSize / aStride; matALoopIdx++) {
        const float* aTilePtrSt = A + l * aStride * matALoopIdx;
        const int* aRowOffsetPtr = dataOffsetMap;
        const float* weightPtr = B;
        // As this function is specialized for hP = 1,
        // every iteration along the h axis uses the full-hP method.
        __m256 c0Vec;
        __m256 c1Vec;
        __m256 c2Vec;
        auto c0VecPtr = (float*)&c0Vec;
        auto c1VecPtr = (float*)&c1Vec;
        auto c2VecPtr = (float*)&c2Vec;
        for (int hLoopIdx = 0; hLoopIdx < h; hLoopIdx++) {
            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            // Initialize the [eP, hP] mat_c tile with bias if it exists.
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
            } else {
                c0Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
            }
            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                aTilePtrSt += aRowOffsetPtr[0];
                aRowOffsetPtr++;
                auto a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
                auto a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
                auto a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
                auto b0Vec = _mm256_broadcast_ss(weightPtr);
                weightPtr++;
                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
            }
            // min-max post process and store.
            MIN_MAX_VEC(c0Vec);
            MIN_MAX_VEC(c1Vec);
            MIN_MAX_VEC(c2Vec);
            ONE_H_STORE_E24(cTilePtrSt);
        }
        NNZMap -= h;
    }
    // remained [l, eSize % eP] X [h/unit, e, unit]
    A += (eSize / aStride) * aStride * l;
    C += (eSize / aStride) * aStride * unit;
    eSize = eSize % aStride; // eSize % 24
    // remained eSize part
    if (eSize) {
        // As this function is specialized for hP = 1,
        // every iteration along the h axis uses the full-hP method.
        __m256 c0Vec;
        __m256 c1Vec;
        __m256 c2Vec;
        auto c0VecPtr = (float*)&c0Vec;
        auto c1VecPtr = (float*)&c1Vec;
        auto c2VecPtr = (float*)&c2Vec;
        for (int hLoopIdx = 0; hLoopIdx < h; hLoopIdx++) {
            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            // Initialize the [eP, hP] mat_c tile with bias if it exists.
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
            } else {
                c0Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
            }
            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                A += dataOffsetMap[0];
                dataOffsetMap++;
                auto a0Vec = _mm256_loadu_ps(A + 0);
                auto a1Vec = _mm256_loadu_ps(A + 8);
                auto a2Vec = _mm256_loadu_ps(A + 16);
                auto b0Vec = _mm256_broadcast_ss(B);
                B++;
                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
            }
            // min-max post process and store.
            MIN_MAX_VEC(c0Vec);
            MIN_MAX_VEC(c1Vec);
            MIN_MAX_VEC(c2Vec);
            // Store only the valid eSize results: full groups of 8, then the tail one by one.
            auto CStorePtr = cTilePtrSt;
            auto cxVecPtr = c0VecPtr;
            if (eSize >= 8) {
                CStorePtr[8 * 0] = cxVecPtr[0];
                CStorePtr[8 * 1] = cxVecPtr[1];
                CStorePtr[8 * 2] = cxVecPtr[2];
                CStorePtr[8 * 3] = cxVecPtr[3];
                CStorePtr[8 * 4] = cxVecPtr[4];
                CStorePtr[8 * 5] = cxVecPtr[5];
                CStorePtr[8 * 6] = cxVecPtr[6];
                CStorePtr[8 * 7] = cxVecPtr[7];
                CStorePtr += 8 * unit;
                cxVecPtr = c1VecPtr;
            }
            if (eSize >= 16) {
                CStorePtr[8 * 0] = cxVecPtr[0];
                CStorePtr[8 * 1] = cxVecPtr[1];
                CStorePtr[8 * 2] = cxVecPtr[2];
                CStorePtr[8 * 3] = cxVecPtr[3];
                CStorePtr[8 * 4] = cxVecPtr[4];
                CStorePtr[8 * 5] = cxVecPtr[5];
                CStorePtr[8 * 6] = cxVecPtr[6];
                CStorePtr[8 * 7] = cxVecPtr[7];
                CStorePtr += 8 * unit;
                cxVecPtr = c2VecPtr;
            }
            for (int i = 0; i < eSize % 8; i++) {
                CStorePtr[8 * i] = cxVecPtr[i];
            }
        }
        NNZMap -= h;
    }
    return;
}
void _AVX_MNNPackedSparseMatMulEpx4EFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter,
                                        const float* postParameters, const float* bias, unsigned int* NNZMap,
                                        int* dataOffsetMap) {
    /*
        mat_a: [eSize/eP, l, eP]
        mat_c: [h/unit, e, unit]
        bias: [h, ]
        parameter[0]: eP * bytes
        parameter[1]: l
        parameter[2]: h
        parameter[3]: h/unit stride, equal to e * unit * sizeof(dataType)
        parameter[4]: unit
        eSize: this tile's real e size, which may differ from eP!
        postParameters[2]: min_val of output
        postParameters[3]: max_val of output
    */
    /*
        This function performs sparse matmul with bias add and a min/max post process.
        The basic process of the dense version is:
            batch_matmul([l, eP], [h/hP, l, hP]) --> [h/hP, eP, hP].
        When mat_b is sparse encoded, the function changes accordingly.
        The whole process is divided into two parts: the full-hP part and the remain part.
        In the full-hP part, each iteration processes hP columns (rows, actually) of mat_b,
        and the non-zero values are encoded hP at a time, contiguously.
        In the remain part, each iteration processes one column (row, actually) of mat_b,
        and the non-zero values are encoded one by one.
        ***********************************************
        Specialization description:
        1. eP = 24, hP = 4, lP = 1;
        2. mat_a is stored in [eSize/eP, l, eP] format;
        3. mat_c is stored in [h/unit, e, unit] format;
        4. the data type is fixed to float32, so bytes = 4;
        5. unit is fixed to 8.
        ***********************************************
        Note that the function keeps the aStride logic for a mat_a that contains more than one
        l * eP tile, but for now eSize is limited to eSize <= eP!
    */
#define ONE_LP_ACT_E24(cVecFirst, cVecSecond, cVecThird) \
    b0Vec = _mm256_broadcast_ss(weightPtr); \
    weightPtr++; \
    cVecFirst = EMULATED_AVX2_FMA(cVecFirst, a0Vec, b0Vec); \
    cVecSecond = EMULATED_AVX2_FMA(cVecSecond, a1Vec, b0Vec); \
    cVecThird = EMULATED_AVX2_FMA(cVecThird, a2Vec, b0Vec);

#define REMAIN_E_ONE_LP_ACT_E24(cVecFirst, cVecSecond, cVecThird) \
    b0Vec = _mm256_broadcast_ss(B); \
    B++; \
    cVecFirst = EMULATED_AVX2_FMA(cVecFirst, a0Vec, b0Vec); \
    cVecSecond = EMULATED_AVX2_FMA(cVecSecond, a1Vec, b0Vec); \
    cVecThird = EMULATED_AVX2_FMA(cVecThird, a2Vec, b0Vec);
#ifdef MNN_X86_USE_ASM
    if (eSize == AVX2_SPARSE_EP && parameter[2] % 4 == 0) {
        // Use the asm kernel when eSize == eP (24) and h is a multiple of 4.
        SparseMatMulParas temp = {C, A, B, NNZMap, dataOffsetMap};
        SparseMatMulParas* tempPtr = &temp;
        _AVX_MNNPackedSparseMatMulEpx4EFMA_ASM(tempPtr, bias, parameter, postParameters);
        return;
    }
#endif
    const size_t aStride = parameter[0] / FP32_BYTES; // element stride; the intrinsic path does not use byte strides.
    const size_t l = parameter[1];
    const size_t h = parameter[2];
    const size_t cStride = parameter[3] / FP32_BYTES; // element stride as well.
    const size_t unit = 8;
    MNN_ASSERT(eSize <= aStride);
    const float minVal = postParameters[2];
    const float maxVal = postParameters[3];
    const int fullHCnt = h / AVX2_SP_BLOCK4 * AVX2_SP_BLOCK4;
    // full [l, eP] X [h/unit, e, unit]
    for (int matALoopIdx = 0; matALoopIdx < eSize / aStride; matALoopIdx++) {
        const float* aTilePtrSt = A + l * aStride * matALoopIdx;
        const int* aRowOffsetPtr = dataOffsetMap;
        const float* weightPtr = B;
        int hLoopIdx = 0;
        // full hP method!
        for (; hLoopIdx < fullHCnt; hLoopIdx += AVX2_SP_BLOCK4) {
            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            __m256 c0Vec;
            __m256 c1Vec;
            __m256 c2Vec;
            __m256 c3Vec;
            __m256 c4Vec;
            __m256 c5Vec;
            __m256 c6Vec;
            __m256 c7Vec;
            __m256 c8Vec;
            __m256 c9Vec;
            __m256 c10Vec;
            __m256 c11Vec;
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c3Vec = _mm256_broadcast_ss(bias + hLoopIdx + 1);
                c6Vec = _mm256_broadcast_ss(bias + hLoopIdx + 2);
                c9Vec = _mm256_broadcast_ss(bias + hLoopIdx + 3);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
                c4Vec = c3Vec;
                c5Vec = c3Vec;
                c7Vec = c6Vec;
                c8Vec = c6Vec;
                c10Vec = c9Vec;
                c11Vec = c9Vec;
            } else {
                // [intrinsic bug] _mm256_zeroall cannot be used here; it does not work after the first iteration.
                c0Vec = _mm256_setzero_ps();
                c3Vec = _mm256_setzero_ps();
                c6Vec = _mm256_setzero_ps();
                c9Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
                c4Vec = _mm256_setzero_ps();
                c5Vec = _mm256_setzero_ps();
                c7Vec = _mm256_setzero_ps();
                c8Vec = _mm256_setzero_ps();
                c10Vec = _mm256_setzero_ps();
                c11Vec = _mm256_setzero_ps();
            }
            {
                __m256 a0Vec;
                __m256 a1Vec;
                __m256 a2Vec;
                __m256 b0Vec;
                for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                    // printf("aRowOffset: %d\t", *aRowOffsetPtr);
                    aTilePtrSt += *aRowOffsetPtr;
                    aRowOffsetPtr++;
                    a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
                    a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
                    a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
                    ONE_LP_ACT_E24(c0Vec, c1Vec, c2Vec);
                    ONE_LP_ACT_E24(c3Vec, c4Vec, c5Vec);
                    ONE_LP_ACT_E24(c6Vec, c7Vec, c8Vec);
                    ONE_LP_ACT_E24(c9Vec, c10Vec, c11Vec);
                }
            }
            {
                auto minVec = _mm256_set1_ps(minVal);
                auto maxVec = _mm256_set1_ps(maxVal);
                MIN_MAX_VEC(c0Vec);
                MIN_MAX_VEC(c1Vec);
                MIN_MAX_VEC(c2Vec);
                MIN_MAX_VEC(c3Vec);
                MIN_MAX_VEC(c4Vec);
                MIN_MAX_VEC(c5Vec);
                MIN_MAX_VEC(c6Vec);
                MIN_MAX_VEC(c7Vec);
                MIN_MAX_VEC(c8Vec);
                MIN_MAX_VEC(c9Vec);
                MIN_MAX_VEC(c10Vec);
                MIN_MAX_VEC(c11Vec);
            }
            TRANSPOSE_4x24_WITH_STORE(cTilePtrSt, unit);
        }
        // remain hP method!
        __m256 c0Vec;
        __m256 c1Vec;
        __m256 c2Vec;
        auto minVec = _mm256_set1_ps(minVal);
        auto maxVec = _mm256_set1_ps(maxVal);
        auto c0VecPtr = (float*)&c0Vec;
        auto c1VecPtr = (float*)&c1Vec;
        auto c2VecPtr = (float*)&c2Vec;
        for (; hLoopIdx < h; hLoopIdx++) {
            float* cTilePtrSt = C + (unit * aStride * matALoopIdx) + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            // Initialize the [eP, hP] mat_c tile with bias if it exists.
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
            } else {
                c0Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
            }
            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                aTilePtrSt += aRowOffsetPtr[0];
                aRowOffsetPtr++;
                auto a0Vec = _mm256_loadu_ps(aTilePtrSt + 0);
                auto a1Vec = _mm256_loadu_ps(aTilePtrSt + 8);
                auto a2Vec = _mm256_loadu_ps(aTilePtrSt + 16);
                auto b0Vec = _mm256_broadcast_ss(weightPtr);
                weightPtr++;
                c0Vec = EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
                c1Vec = EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
                c2Vec = EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
            }
            // min-max post process and store.
            MIN_MAX_VEC(c0Vec);
            MIN_MAX_VEC(c1Vec);
            MIN_MAX_VEC(c2Vec);
            ONE_H_STORE_E24(cTilePtrSt);
        }
        NNZMap -= fullHCnt / AVX2_SP_BLOCK4 + h - fullHCnt;
    }
    // remained [l, eSize % eP] X [h/unit, e, unit]
    A += (eSize / aStride) * aStride * l;
    C += (eSize / aStride) * aStride * unit;
    eSize = eSize % aStride; // eSize % 24
    // remained eSize part
    if (eSize) {
        int hLoopIdx = 0;
        for (; hLoopIdx < fullHCnt; hLoopIdx += AVX2_SP_BLOCK4) {
            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            __m256 c0Vec;
            __m256 c1Vec;
            __m256 c2Vec;
            __m256 c3Vec;
            __m256 c4Vec;
            __m256 c5Vec;
            __m256 c6Vec;
            __m256 c7Vec;
            __m256 c8Vec;
            __m256 c9Vec;
            __m256 c10Vec;
            __m256 c11Vec;
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c3Vec = _mm256_broadcast_ss(bias + hLoopIdx + 1);
                c6Vec = _mm256_broadcast_ss(bias + hLoopIdx + 2);
                c9Vec = _mm256_broadcast_ss(bias + hLoopIdx + 3);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
                c4Vec = c3Vec;
                c5Vec = c3Vec;
                c7Vec = c6Vec;
                c8Vec = c6Vec;
                c10Vec = c9Vec;
                c11Vec = c9Vec;
            } else {
                // [intrinsic bug] _mm256_zeroall cannot be used here; it does not work after the first iteration.
                c0Vec = _mm256_setzero_ps();
                c3Vec = _mm256_setzero_ps();
                c6Vec = _mm256_setzero_ps();
                c9Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
                c4Vec = _mm256_setzero_ps();
                c5Vec = _mm256_setzero_ps();
                c7Vec = _mm256_setzero_ps();
                c8Vec = _mm256_setzero_ps();
                c10Vec = _mm256_setzero_ps();
                c11Vec = _mm256_setzero_ps();
            }
            {
                __m256 a0Vec;
                __m256 a1Vec;
                __m256 a2Vec;
                __m256 b0Vec;
                for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                    A += *dataOffsetMap;
                    dataOffsetMap++;
                    a0Vec = _mm256_loadu_ps(A + 0);
                    a1Vec = _mm256_loadu_ps(A + 8);
                    a2Vec = _mm256_loadu_ps(A + 16);
                    REMAIN_E_ONE_LP_ACT_E24(c0Vec, c1Vec, c2Vec);
                    REMAIN_E_ONE_LP_ACT_E24(c3Vec, c4Vec, c5Vec);
                    REMAIN_E_ONE_LP_ACT_E24(c6Vec, c7Vec, c8Vec);
                    REMAIN_E_ONE_LP_ACT_E24(c9Vec, c10Vec, c11Vec);
                }
            }
            {
                auto minVec = _mm256_set1_ps(minVal);
                auto maxVec = _mm256_set1_ps(maxVal);
                MIN_MAX_VEC(c0Vec);
                MIN_MAX_VEC(c1Vec);
                MIN_MAX_VEC(c2Vec);
                MIN_MAX_VEC(c3Vec);
                MIN_MAX_VEC(c4Vec);
                MIN_MAX_VEC(c5Vec);
                MIN_MAX_VEC(c6Vec);
                MIN_MAX_VEC(c7Vec);
                MIN_MAX_VEC(c8Vec);
                MIN_MAX_VEC(c9Vec);
                MIN_MAX_VEC(c10Vec);
                MIN_MAX_VEC(c11Vec);
            }
            REMAIN_TRANSPOSE_4x24_WITH_STORE(cTilePtrSt, unit);
        }
        // remained h part
        __m256 c0Vec;
        __m256 c1Vec;
        __m256 c2Vec;
        auto c0VecPtr = (float*)&c0Vec;
        auto c1VecPtr = (float*)&c1Vec;
        auto c2VecPtr = (float*)&c2Vec;
        auto minVec = _mm256_set1_ps(minVal);
        auto maxVec = _mm256_set1_ps(maxVal);
        for (; hLoopIdx < h; hLoopIdx++) {
            float* cTilePtrSt = C + (hLoopIdx / unit * cStride) + (hLoopIdx % unit);
            size_t nonZeroCnt = *NNZMap;
            NNZMap++;
            // Initialize the [eP, hP] mat_c tile with bias if it exists.
            if (bias != nullptr) {
                c0Vec = _mm256_broadcast_ss(bias + hLoopIdx);
                c1Vec = c0Vec;
                c2Vec = c0Vec;
            } else {
                c0Vec = _mm256_setzero_ps();
                c1Vec = _mm256_setzero_ps();
                c2Vec = _mm256_setzero_ps();
            }
            __m256 a0Vec;
            __m256 a1Vec;
            __m256 a2Vec;
            for (int lLoopIdx = 0; lLoopIdx < nonZeroCnt; lLoopIdx++) {
                A += *dataOffsetMap;
                dataOffsetMap++;
                a0Vec = _mm256_loadu_ps(A + 0);
                a1Vec = _mm256_loadu_ps(A + 8);
                a2Vec = _mm256_loadu_ps(A + 16);
                auto b0Vec = _mm256_broadcast_ss(B);
                B++;
                EMULATED_AVX2_FMA(c0Vec, a0Vec, b0Vec);
                EMULATED_AVX2_FMA(c1Vec, a1Vec, b0Vec);
                EMULATED_AVX2_FMA(c2Vec, a2Vec, b0Vec);
            }
            // min-max post process and store.
            MIN_MAX_VEC(c0Vec);
            MIN_MAX_VEC(c1Vec);
            MIN_MAX_VEC(c2Vec);
            // Store only the valid eSize results: full groups of 8, then the tail one by one.
            auto CStorePtr = cTilePtrSt;
            auto cxVecPtr = c0VecPtr;
            if (eSize >= 8) {
                CStorePtr[8 * 0] = cxVecPtr[0];
                CStorePtr[8 * 1] = cxVecPtr[1];
                CStorePtr[8 * 2] = cxVecPtr[2];
                CStorePtr[8 * 3] = cxVecPtr[3];
                CStorePtr[8 * 4] = cxVecPtr[4];
                CStorePtr[8 * 5] = cxVecPtr[5];
                CStorePtr[8 * 6] = cxVecPtr[6];
                CStorePtr[8 * 7] = cxVecPtr[7];
                CStorePtr += 8 * unit;
                cxVecPtr = c1VecPtr;
            }
            if (eSize >= 16) {
                CStorePtr[8 * 0] = cxVecPtr[0];
                CStorePtr[8 * 1] = cxVecPtr[1];
                CStorePtr[8 * 2] = cxVecPtr[2];
                CStorePtr[8 * 3] = cxVecPtr[3];
                CStorePtr[8 * 4] = cxVecPtr[4];
                CStorePtr[8 * 5] = cxVecPtr[5];
                CStorePtr[8 * 6] = cxVecPtr[6];
                CStorePtr[8 * 7] = cxVecPtr[7];
                CStorePtr += 8 * unit;
                cxVecPtr = c2VecPtr;
            }
            for (int i = 0; i < eSize % 8; i++) {
                CStorePtr[8 * i] = cxVecPtr[i];
            }
        }
        NNZMap -= h;
    }
    return;
#undef REMAIN_E_ONE_LP_ACT_E24
#undef ONE_LP_ACT_E24
}
#undef AVX2_SP_BLOCK4
#undef AVX2_SPARSE_EP
#undef FP32_BYTES
#undef EMULATED_AVX2_FMA
#undef MIN_MAX_VEC
#undef ONE_H_STORE_E24
#undef TRANSPOSE_4x4_WITH_STORE
#undef TRANSPOSE_4x24_WITH_STORE
#undef REMAIN_TRANSPOSE_4x24_WITH_STORE