
//
//  PackedFunction.cpp
//  MNN
//
//  Created by MNN on 2021/07/05.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include <float.h>
#include <string.h>
#include <algorithm>
#include <limits>
#include <vector>
#include "FunctionSummary.hpp"
#include "core/Macro.h"
#include "backend/cpu/CPUPool.hpp"
#include "backend/cpu/BinaryUtils.hpp"
#include "Vec8.hpp"
#define PACK_UNIT 8
extern "C" {
void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count);
void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber);
void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
                                       size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
                                       size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step);
void _AVX_MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode);
void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub);
void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
                                     size_t weight_y_step, size_t dilateX_step, size_t dilateY_step);
void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weight, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter);
void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu);
void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weight, float *dest, size_t ow, const float* bias, const float* parameter);
void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height);
void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height);
void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                     size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep);
void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters);
}
void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    for (int i = 0; i < count; ++i) {
        auto s = source + i * srcStride;
        auto d = dest + i * dstStride;
        _mm256_storeu_ps(d, _mm256_loadu_ps(s));
    }
}
void _AVX_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    for (int i = 0; i < count; ++i) {
        auto s = source + i * srcStride;
        auto d = dest + i * dstStride;
        _mm256_storeu_ps(d, _mm256_add_ps(_mm256_loadu_ps(s), _mm256_loadu_ps(d)));
    }
}
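// Channel-wise PReLU on packed (C8) data: lanes below zero are scaled by the
// per-channel slope, non-negative lanes pass through unchanged.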
void _AVX_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    auto zero   = _mm256_set1_ps(0.0f);
    int sizeC8  = sizeQuad;
    for (int j = 0; j < depthQuad; j++) {
        auto slopeZ       = _mm256_loadu_ps(slope + PACK_UNIT * j);
        const float* srcZ = src + PACK_UNIT * j * sizeQuad;
        float* dstZ       = dst + PACK_UNIT * j * sizeQuad;
        for (int i = 0; i < sizeC8; i++) {
            auto srcV  = _mm256_loadu_ps(srcZ);
            auto mask0 = _mm256_cmp_ps(srcV, zero, 0x01); // srcV < 0
            auto mask1 = _mm256_cmp_ps(srcV, zero, 0x0D); // srcV >= 0
            auto other = _mm256_mul_ps(srcV, slopeZ);
            _mm256_storeu_ps(dstZ, _mm256_add_ps(_mm256_and_ps(other, mask0), _mm256_and_ps(srcV, mask1)));
            srcZ += PACK_UNIT;
            dstZ += PACK_UNIT;
        }
    }
}
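// Computes C = clamp(A + B) with B broadcast along each row; parameters[2] and
// parameters[3] supply the lower and upper clamp bounds.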
void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
    auto minF = _mm256_broadcast_ss(parameters + 2);
    auto maxF = _mm256_broadcast_ss(parameters + 3);
    for (int y = 0; y < height; ++y) {
        auto a  = A + aStride * y;
        auto b  = B + PACK_UNIT * y;
        auto bv = _mm256_loadu_ps(b);
        auto c  = C + cStride * y;
        for (int x = 0; x < width; ++x) {
            auto av = _mm256_loadu_ps(a);
            auto cv = _mm256_add_ps(av, bv);
            cv = _mm256_min_ps(cv, maxF);
            cv = _mm256_max_ps(cv, minF);
            _mm256_storeu_ps(c, cv);
            a += PACK_UNIT;
            c += PACK_UNIT;
        }
    }
}
void _AVX_MNNConvRunForUnitDepthWise(float* dst, const float* src, const float* weight, size_t fw, size_t fh,
                                     size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
    int fx, fy;
    __m256 dstValue       = _mm256_setzero_ps();
    const float* src_z    = src;
    const float* weight_z = weight;
    for (fy = 0; fy < fh; ++fy) {
        const float* src_y    = src_z + fy * dilateY_step;
        const float* weight_y = weight_z + fy * weight_y_step;
        for (fx = 0; fx < fw; ++fx) {
            const float* weight_x = weight_y + PACK_UNIT * fx;
            const float* src_x    = src_y + fx * dilateX_step;
            dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
        }
    }
    _mm256_storeu_ps(dst, dstValue);
}
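// Depthwise convolution over whole output lines: the main loop produces 4 output
// pixels per iteration to reuse each loaded weight vector, and the tail loop
// handles the remaining width one pixel at a time.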
void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                     size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                     size_t srcHStep, size_t dstHStep) {
    int dx, fx, fy;
    const int unit  = 4;
    int widthUnit   = width / unit;
    int widthRemain = width - widthUnit * unit;
    const float* weight_z = weight;
    for (int y = 0; y < height; ++y) {
        auto srcY = src + y * srcHStep;
        auto dstY = dst + y * dstHStep;
        for (dx = 0; dx < widthUnit; ++dx) {
            auto dstValue0 = _mm256_setzero_ps();
            auto dstValue1 = _mm256_setzero_ps();
            auto dstValue2 = _mm256_setzero_ps();
            auto dstValue3 = _mm256_setzero_ps();
            for (fy = 0; fy < fh; ++fy) {
                const float* src_y    = srcY + fy * dilateY_step;
                const float* weight_y = weight_z + fy * fw * PACK_UNIT;
                for (fx = 0; fx < fw; ++fx) {
                    const float* src_x    = src_y + fx * dilateX_step;
                    const float* weight_x = weight_y + PACK_UNIT * fx;
                    auto weightValue = _mm256_loadu_ps(weight_x);
                    dstValue0 = _mm256_add_ps(dstValue0, _mm256_mul_ps(_mm256_loadu_ps(src_x + 0 * src_w_setup), weightValue));
                    dstValue1 = _mm256_add_ps(dstValue1, _mm256_mul_ps(_mm256_loadu_ps(src_x + 1 * src_w_setup), weightValue));
                    dstValue2 = _mm256_add_ps(dstValue2, _mm256_mul_ps(_mm256_loadu_ps(src_x + 2 * src_w_setup), weightValue));
                    dstValue3 = _mm256_add_ps(dstValue3, _mm256_mul_ps(_mm256_loadu_ps(src_x + 3 * src_w_setup), weightValue));
                }
            }
            _mm256_storeu_ps(dstY + PACK_UNIT * 0, dstValue0);
            _mm256_storeu_ps(dstY + PACK_UNIT * 1, dstValue1);
            _mm256_storeu_ps(dstY + PACK_UNIT * 2, dstValue2);
            _mm256_storeu_ps(dstY + PACK_UNIT * 3, dstValue3);
            dstY += PACK_UNIT * unit;
            srcY += unit * src_w_setup;
        }
        for (dx = 0; dx < widthRemain; ++dx) {
            float* dst_x = dstY + dx * PACK_UNIT;
            auto dstValue = _mm256_setzero_ps();
            const float* src_z    = srcY + src_w_setup * dx;
            const float* weight_z = weight;
            for (fy = 0; fy < fh; ++fy) {
                const float* src_y    = src_z + fy * dilateY_step;
                const float* weight_y = weight_z + fy * fw * PACK_UNIT;
                for (fx = 0; fx < fw; ++fx) {
                    const float* weight_x = weight_y + PACK_UNIT * fx;
                    const float* src_x    = src_y + fx * dilateX_step;
                    dstValue = _mm256_add_ps(dstValue, _mm256_mul_ps(_mm256_loadu_ps(src_x), _mm256_loadu_ps(weight_x)));
                }
            }
            _mm256_storeu_ps(dst_x, dstValue);
        }
    }
}
static MNNBinaryExecute _AVX2_MNNSelectBinaryFunctionForFloat(int opType) {
    auto vecF = MNN::selectVector<Vec8, 8>(opType);
    if (nullptr != vecF) {
        return vecF;
    }
    return MNN::MNNGetCoreFunctions()->MNNSelectBinaryFunctionForFloat(opType);
}
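// Strided copy of one packed C8 float block (32 bytes) per element; _selectBlit
// returns it only when the requested element size is 32 bytes.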
static void _8BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds) {
    auto src = (float*)srcO;
    auto dst = (float*)dstO;
    for (int i = 0; i < size; ++i) {
        _mm256_storeu_ps(dst, _mm256_loadu_ps(src));
        src += (8 * stride);
        dst += (8 * ds);
    }
}
static MNNCopyWithStride _selectBlit(int bytesC4) {
    if (32 == bytesC4) {
        return _8BitcopyWithStrideC4;
    }
    return nullptr;
}
void _AVX_MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber,
                             size_t biasNumber) {
    for (int z = 0; z < biasNumber; ++z) {
        float* dstZ       = dst + planeNumber * PACK_UNIT * z;
        const float* srcZ = src + planeNumber * PACK_UNIT * z;
        auto biasZ  = Vec8::load(bias + PACK_UNIT * z);
        auto alphaZ = Vec8::load(alpha + PACK_UNIT * z);
        for (int p = 0; p < planeNumber; ++p) {
            float* dstX       = dstZ + PACK_UNIT * p;
            const float* srcX = srcZ + PACK_UNIT * p;
            Vec8::save(dstX, (Vec8::load(srcX) * alphaZ) + biasZ);
        }
    }
}
void _AVX_MNNDeconvRunForUnitDepthWise(const float* dst, float* src, const float* weight, size_t fw, size_t fh,
                                       size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
    int fx, fy;
    float* src_z          = src;
    const float* weight_z = weight;
    Vec8 dstV = Vec8::load(dst);
    for (fy = 0; fy < fh; ++fy) {
        float* src_y          = src_z + fy * dilateY_step;
        const float* weight_y = weight_z + fy * weight_y_step;
        for (fx = 0; fx < fw; ++fx) {
            Vec8 weight_x = Vec8::load(weight_y + PACK_UNIT * fx);
            Vec8 src_x    = Vec8::load(src_y + fx * dilateX_step);
            Vec8::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
        }
    }
}
void _AVX_MNNDeconvRunForLineDepthwise(const float* dst, float* src, const float* weight, size_t width, size_t src_w_setup,
                                       size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
    int dx;
    for (dx = 0; dx < width; ++dx) {
        const float* dst_x = dst + dx * PACK_UNIT;
        float* src_dx      = src + src_w_setup * dx;
        _AVX_MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * PACK_UNIT, dilateX_step, dilateY_step);
    }
}
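// Loads one packed (C8) sample for grid sampling; out-of-range coordinates return
// zeros when padMode requests zero padding, otherwise they are clamped to the border.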
static __m256 MNNGridSampleLoadSample(int h, int w, const float* buffer, int height, int width, bool padMode) {
    if (h < 0 || h >= height || w < 0 || w >= width) {
        if (padMode == true) { // padMode == BorderMode_ZEROS
            return _mm256_setzero_ps();
        }
        // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER
        // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1),
        // the leftover reflections degrade to GridSamplePaddingMode_BORDER
        h = h < 0 ? 0 : (h > (height - 1) ? (height - 1) : h);
        w = w < 0 ? 0 : (w > (width - 1) ? (width - 1) : w);
    }
    return _mm256_loadu_ps(buffer + h * width * PACK_UNIT + w * PACK_UNIT);
}
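// Interpolates one output row for grid sampling: nearest neighbour when sampleMode
// is set, otherwise bilinear blending of the four neighbouring samples.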
void _AVX_MNNGridSampleInterp(float* outputPtr, const float* inputPtr, const float* cordPtr, size_t inH, size_t inW, size_t outW, bool sampleMode, bool padMode) {
    for (auto ow = 0; ow < outW; ++ow) {
        auto w = cordPtr[2 * ow + 0];
        auto h = cordPtr[2 * ow + 1];
        __m256 interp;
        if (sampleMode == true) { // sampleMode == SampleMode_NEAREST
            int nh = ::floor(h + 0.5f);
            int nw = ::floor(w + 0.5f);
            interp = MNNGridSampleLoadSample(nh, nw, inputPtr, inH, inW, padMode);
        } else { // sampleMode == GridSampleMode_BILINEAR
            int w0_h = ::floor(h);
            int w0_w = ::floor(w);
            int w1_h = ::ceil(h);
            int w1_w = ::ceil(w);
            auto oneV = _mm256_set1_ps(1.0f);
            __m256 i00 = MNNGridSampleLoadSample(w0_h, w0_w, inputPtr, inH, inW, padMode);
            __m256 i01 = MNNGridSampleLoadSample(w0_h, w1_w, inputPtr, inH, inW, padMode);
            __m256 i10 = MNNGridSampleLoadSample(w1_h, w0_w, inputPtr, inH, inW, padMode);
            __m256 i11 = MNNGridSampleLoadSample(w1_h, w1_w, inputPtr, inH, inW, padMode);
            auto f0 = _mm256_set1_ps((float)w1_w - w);
            auto f1 = _mm256_sub_ps(oneV, f0);
            auto h0 = _mm256_set1_ps((float)w1_h - h);
            auto h1 = _mm256_sub_ps(oneV, h0);
            __m256 i0 = _mm256_add_ps(_mm256_mul_ps(i00, f0), _mm256_mul_ps(i01, f1));
            __m256 i1 = _mm256_add_ps(_mm256_mul_ps(i10, f0), _mm256_mul_ps(i11, f1));
            interp = _mm256_add_ps(_mm256_mul_ps(i0, h0), _mm256_mul_ps(i1, h1));
        }
        _mm256_storeu_ps(outputPtr + PACK_UNIT * ow, interp);
    }
}
void _AVX_MNNMatrixAdd(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    for (int y = 0; y < height; ++y) {
        auto a = A + aStride * y;
        auto b = B + bStride * y;
        auto c = C + cStride * y;
        for (int x = 0; x < widthC4; ++x) {
            _mm256_storeu_ps(c + PACK_UNIT * x, _mm256_add_ps(_mm256_loadu_ps(b + PACK_UNIT * x), _mm256_loadu_ps(a + PACK_UNIT * x)));
        }
    }
}
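// Merge step used by the Strassen matrix-multiply path: combines the partial
// results held in c11/c12/c21/c22 with the temporary buffer xAddr, row by row.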
void _AVX_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c22, float* xAddr, size_t cStride, size_t eSub, size_t hSub) {
    const int unit = PACK_UNIT;
    for (int y = 0; y < hSub; ++y) {
        auto c11Y = c11 + y * cStride;
        auto c12Y = c12 + y * cStride;
        auto c22Y = c22 + y * cStride;
        auto c21Y = c21 + y * cStride;
        auto xY   = xAddr + y * eSub * unit;
        for (int x = 0; x < eSub; ++x) {
            auto xv   = _mm256_loadu_ps(xY + unit * x);
            auto c21v = _mm256_loadu_ps(c21Y + unit * x);
            auto c11v = _mm256_loadu_ps(c11Y + unit * x);
            auto c22v = _mm256_loadu_ps(c22Y + unit * x);
            auto c12v = _mm256_loadu_ps(c12Y + unit * x);
            c12v = _mm256_add_ps(c12v, xv);
            c21v = _mm256_add_ps(c12v, c21v);
            c12v = _mm256_add_ps(c22v, c12v);
            c22v = _mm256_add_ps(c22v, c21v);
            c12v = _mm256_add_ps(c11v, c12v);
            _mm256_storeu_ps(c12Y + unit * x, c12v);
            _mm256_storeu_ps(c22Y + unit * x, c22v);
            _mm256_storeu_ps(c21Y + unit * x, c21v);
        }
    }
}
void _AVX_MNNMatrixSub(float* C, const float* A, const float* B, size_t widthC4, size_t cStride, size_t aStride,
                       size_t bStride, size_t height) {
    for (int y = 0; y < height; ++y) {
        auto a = A + aStride * y;
        auto b = B + bStride * y;
        auto c = C + cStride * y;
        for (int x = 0; x < widthC4; ++x) {
            _mm256_storeu_ps(c + PACK_UNIT * x, _mm256_sub_ps(_mm256_loadu_ps(a + PACK_UNIT * x), _mm256_loadu_ps(b + PACK_UNIT * x)));
        }
    }
}
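// Multiply-and-destination-transform for the F(2,3) depthwise path: transformed
// inputs from cacheLine are multiplied by the transformed weights, reduced to two
// outputs per unit (o0, o1), then clamped to [parameter[2], parameter[3]].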
void _AVX_MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weight, float *dest, int cacheLineSize, int ow, const float* bias, const float* parameter) {
    int unit = ow / 2;
    MNN_ASSERT(cacheLineSize >= 1);
    auto biasF = Vec8::load(bias);
    auto minF  = Vec8(parameter[2]);
    auto maxF  = Vec8(parameter[3]);
    auto SRC_TILE_UNIT = 4 * PACK_UNIT;
    auto DST_TILE_UNIT = 2 * PACK_UNIT;
    for (int x = 0; x < unit; ++x) {
        auto offset = SRC_TILE_UNIT * x;
        int i = 0;
        Vec8 m0 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
        Vec8 m1 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
        Vec8 m2 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
        Vec8 m3 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 3) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 3);
        for (i = 1; i < cacheLineSize; ++i) {
            m0 = m0 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
            m1 = m1 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
            m2 = m2 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
            m3 = m3 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 3) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 3);
        }
        auto o0 = m0 + m1 + m2 + biasF;
        auto o1 = m1 - m2 + m3 + biasF;
        o0 = Vec8::min(maxF, o0);
        o1 = Vec8::min(maxF, o1);
        o0 = Vec8::max(minF, o0);
        o1 = Vec8::max(minF, o1);
        Vec8::save(dest + DST_TILE_UNIT * x + 0 * PACK_UNIT, o0);
        Vec8::save(dest + DST_TILE_UNIT * x + 1 * PACK_UNIT, o1);
    }
    if (unit * 2 < ow) {
        auto offset = SRC_TILE_UNIT * unit;
        int i = 0;
        Vec8 m0 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
        Vec8 m1 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
        Vec8 m2 = Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
        for (i = 1; i < cacheLineSize; ++i) {
            m0 = m0 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 0) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 0);
            m1 = m1 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 1) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 1);
            m2 = m2 + Vec8::load(weight + i * SRC_TILE_UNIT + PACK_UNIT * 2) * Vec8::load(cacheLine[i] + offset + PACK_UNIT * 2);
        }
        auto o0 = m0 + m1 + m2 + biasF;
        o0 = Vec8::min(maxF, o0);
        o0 = Vec8::max(minF, o0);
        Vec8::save(dest + DST_TILE_UNIT * unit, o0);
    }
}
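// Winograd F(2,3)-style source transform for the depthwise kernel: each window of
// 4 packed inputs yields 4 transformed vectors (m0..m3). This variant handles a
// run of units that needs no boundary padding.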
static void _AVX_MNNConvDwF23SourceTransUnit(const float *source, float *dest, size_t unit) {
    if (unit <= 0) {
        return;
    }
    Vec8 v0 = Vec8::load(source + PACK_UNIT * 0);
    Vec8 v1 = Vec8::load(source + PACK_UNIT * 1);
    Vec8 v2;
    Vec8 v3;
    source += 2 * PACK_UNIT;
    for (int x = 0; x < unit; ++x) {
        v2 = Vec8::load(source + 0 * PACK_UNIT);
        v3 = Vec8::load(source + 1 * PACK_UNIT);
        auto m0 = v0 - v2;
        auto m1 = v1 + v2;
        auto m2 = v2 - v1;
        auto m3 = v3 - v1;
        Vec8::save(dest + PACK_UNIT * 0, m0);
        Vec8::save(dest + PACK_UNIT * 1, m1);
        Vec8::save(dest + PACK_UNIT * 2, m2);
        Vec8::save(dest + PACK_UNIT * 3, m3);
        source += (2 * PACK_UNIT);
        dest += (4 * PACK_UNIT);
        v0 = v2;
        v1 = v3;
    }
}
void _AVX_MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int iw, int pad, int su, int eu) {
    for (int x = 0; x < su; ++x) {
        auto dstX = dest + 4 * PACK_UNIT * x;
        auto sx   = x * 2 - (int)pad;
        auto ex   = sx + 4;
        auto clampSx = std::max(sx, 0);
        auto clampEx = std::min(ex, (int)iw);
        Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (int i = clampSx; i < clampEx; ++i) {
            v[i - sx] = Vec8::load(source + PACK_UNIT * i);
        }
        auto m0 = v[0] - v[2];
        auto m1 = v[1] + v[2];
        auto m2 = v[2] - v[1];
        auto m3 = v[3] - v[1];
        Vec8::save(dstX + PACK_UNIT * 0, m0);
        Vec8::save(dstX + PACK_UNIT * 1, m1);
        Vec8::save(dstX + PACK_UNIT * 2, m2);
        Vec8::save(dstX + PACK_UNIT * 3, m3);
    }
    _AVX_MNNConvDwF23SourceTransUnit(source + PACK_UNIT * (su * 2 - pad), dest + PACK_UNIT * 4 * su, eu - su);
    for (int x = eu; x < unit; ++x) {
        auto dstX = dest + PACK_UNIT * 4 * x;
        auto sx   = x * 2 - (int)pad;
        auto ex   = sx + 4;
        auto clampSx = std::max(sx, 0);
        auto clampEx = std::min(ex, (int)iw);
        Vec8 v[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (int i = clampSx; i < clampEx; ++i) {
            v[i - sx] = Vec8::load(source + PACK_UNIT * i);
        }
        auto m0 = v[0] - v[2];
        auto m1 = v[1] + v[2];
        auto m2 = v[2] - v[1];
        auto m3 = v[3] - v[1];
        Vec8::save(dstX + PACK_UNIT * 0, m0);
        Vec8::save(dstX + PACK_UNIT * 1, m1);
        Vec8::save(dstX + PACK_UNIT * 2, m2);
        Vec8::save(dstX + PACK_UNIT * 3, m3);
    }
}
void _AVX_MNNConvDwF23MulTransUnit(float **cacheLine, const float *weight, float *dest, size_t ow, const float* bias, const float* parameter) {
    int unit = ow / 2;
    auto SRC_TILE_UNIT = 4 * PACK_UNIT;
    auto DST_TILE_UNIT = 2 * PACK_UNIT;
    auto w00 = Vec8::load(weight + 0 * SRC_TILE_UNIT + PACK_UNIT * 0);
    auto w01 = Vec8::load(weight + 0 * SRC_TILE_UNIT + PACK_UNIT * 1);
    auto w02 = Vec8::load(weight + 0 * SRC_TILE_UNIT + PACK_UNIT * 2);
    auto w03 = Vec8::load(weight + 0 * SRC_TILE_UNIT + PACK_UNIT * 3);
    auto w10 = Vec8::load(weight + 1 * SRC_TILE_UNIT + PACK_UNIT * 0);
    auto w11 = Vec8::load(weight + 1 * SRC_TILE_UNIT + PACK_UNIT * 1);
    auto w12 = Vec8::load(weight + 1 * SRC_TILE_UNIT + PACK_UNIT * 2);
    auto w13 = Vec8::load(weight + 1 * SRC_TILE_UNIT + PACK_UNIT * 3);
    auto w20 = Vec8::load(weight + 2 * SRC_TILE_UNIT + PACK_UNIT * 0);
    auto w21 = Vec8::load(weight + 2 * SRC_TILE_UNIT + PACK_UNIT * 1);
    auto w22 = Vec8::load(weight + 2 * SRC_TILE_UNIT + PACK_UNIT * 2);
    auto w23 = Vec8::load(weight + 2 * SRC_TILE_UNIT + PACK_UNIT * 3);
    auto biasF = Vec8::load(bias);
    auto minF  = Vec8(parameter[2]);
    auto maxF  = Vec8(parameter[3]);
    for (int x = 0; x < unit; ++x) {
        auto offset = PACK_UNIT * 4 * x;
        Vec8 m0 = w00 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 0);
        Vec8 m1 = w01 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 1);
        Vec8 m2 = w02 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 2);
        Vec8 m3 = w03 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 3);
        m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 0);
        m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 1);
        m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 2);
        m3 = m3 + w13 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 3);
        m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 0);
        m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 1);
        m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 2);
        m3 = m3 + w23 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 3);
        auto o0 = m0 + m1 + m2 + biasF;
        auto o1 = m1 - m2 + m3 + biasF;
        o0 = Vec8::min(maxF, o0);
        o1 = Vec8::min(maxF, o1);
        o0 = Vec8::max(minF, o0);
        o1 = Vec8::max(minF, o1);
        Vec8::save(dest + DST_TILE_UNIT * x + 0 * PACK_UNIT, o0);
        Vec8::save(dest + DST_TILE_UNIT * x + 1 * PACK_UNIT, o1);
    }
    if (unit * 2 < ow) {
        auto offset = PACK_UNIT * 4 * unit;
        Vec8 m0 = w00 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 0);
        Vec8 m1 = w01 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 1);
        Vec8 m2 = w02 * Vec8::load(cacheLine[0] + offset + PACK_UNIT * 2);
        m0 = m0 + w10 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 0);
        m1 = m1 + w11 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 1);
        m2 = m2 + w12 * Vec8::load(cacheLine[1] + offset + PACK_UNIT * 2);
        m0 = m0 + w20 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 0);
        m1 = m1 + w21 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 1);
        m2 = m2 + w22 * Vec8::load(cacheLine[2] + offset + PACK_UNIT * 2);
        auto o0 = m0 + m1 + m2 + biasF;
        o0 = Vec8::min(maxF, o0);
        o0 = Vec8::max(minF, o0);
        Vec8::save(dest + DST_TILE_UNIT * unit, o0);
    }
}
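// Registers the AVX implementations above in the CPU backend's CoreFunctions table.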
void _AVX_ExtraInit(void* functions) {
    auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
    coreFunction->MNNSelectBlitFunction = _selectBlit;
    coreFunction->MNNPoolingAvg = (decltype(coreFunction->MNNPoolingAvg))(MNN::poolingAvg<float, Vec8, 8>);
    // Set min value as -(1 << 24) for max pooling
    coreFunction->MNNPoolingMax = (decltype(coreFunction->MNNPoolingMax))(MNN::poolingMax<float, Vec8, 8, -16777216>);
    coreFunction->MNNSelectBinaryFunctionForFloat = _AVX2_MNNSelectBinaryFunctionForFloat;
    coreFunction->MNNCopyC4WithStride = _AVX_MNNCopyC4WithStride;
    coreFunction->MNNAddC4WithStride = _AVX_MNNAddC4WithStride;
    coreFunction->MNNScaleAndAddBias = _AVX_MNNScaleAndAddBias;
    coreFunction->MNNMatrixAdd = _AVX_MNNMatrixAdd;
    coreFunction->MNNMatrixSub = _AVX_MNNMatrixSub;
    coreFunction->MNNConvRunForUnitDepthWise = _AVX_MNNConvRunForUnitDepthWise;
    coreFunction->MNNConvRunForLineDepthwise = _AVX_MNNConvRunForLineDepthwise;
    coreFunction->MNNAxByClampBroadcastUnit = _AVX_MNNAxByClampBroadcastUnit;
    coreFunction->MNNStrassenMergeCFunction = _AVX_MNNStrassenMergeCFunction;
    coreFunction->MNNMultiAndDestTransformCommon23 = _AVX_MNNMultiAndDestTransformCommon23;
    coreFunction->MNNSourceTransformCommonF23 = _AVX_MNNSourceTransformCommonF23;
    coreFunction->MNNConvDwF23MulTransUnit = _AVX_MNNConvDwF23MulTransUnit;
    coreFunction->MNNReluWithSlopeChannel = _AVX_MNNReluWithSlopeChannel;
    coreFunction->MNNDeconvRunForLineDepthwise = _AVX_MNNDeconvRunForLineDepthwise;
    coreFunction->MNNDeconvRunForUnitDepthWise = _AVX_MNNDeconvRunForUnitDepthWise;
    coreFunction->MNNGridSampleInterp = _AVX_MNNGridSampleInterp;
    // sparse conv funcs
    coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode;
    coreFunction->MNNPackedSparseMatMulEpx1 = _AVX_MNNPackedSparseMatMulEpx1EFMA;
    coreFunction->MNNPackedSparseMatMulEpx4 = _AVX_MNNPackedSparseMatMulEpx4EFMA;
}