ConvolutionWinograd.cpp 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. //
  2. // ConvolutionWinograd.cpp
  3. // MNN
  4. //
  5. // Created by MNN on 2018/08/20.
  6. // Copyright © 2018, Alibaba Group Holding Limited
  7. //
  8. #include "backend/cpu/compute/ConvolutionWinograd.hpp"
  9. #include <math.h>
  10. #include "backend/cpu/compute/CommonOptFunction.h"
  11. #include "core/Concurrency.h"
  12. #include "backend/cpu/compute/ConvOpt.h"
  13. #include "core/Macro.h"
  14. #include "core/TensorUtils.hpp"
  15. #include "math/WingoradGenerater.hpp"
  16. #include <MNN/AutoTime.hpp>
  17. #include "common/MemoryFormater.h"
  18. #ifdef MNN_USE_NEON
  19. #include <arm_neon.h>
  20. #endif
  21. #define CONVOLUTION_WINOGRAD_MAX_UNIT 8
  22. #define CONVOLUTION_WINOGRAD_MIN_UNIT 2
  23. constexpr int FULSE_THRESHHOLD_NUMERATOR = 8;
  24. constexpr int FULSE_THRESHHOLD_DENOMINATOR = 10;
  25. using namespace MNN::Math;
  26. //#define MNN_WINOGRAD_PRINT_REDUCE_RATE
  27. //#define MNN_WINO_TRANFORM_TEST_CLOSE
  28. namespace MNN {
  29. ConvolutionWinograd::ConvolutionWinograd(const Convolution2DCommon *convOp, const Tensor *input, const Tensor *output,
  30. Backend *b, const float *originWeight, size_t originWeightSize,
  31. const float *bias, size_t biasSize, int unit)
  32. : MNN::CPUConvolution(convOp, b) {
  33. auto core = static_cast<CPUBackend*>(backend())->functions();
  34. int pack = core->pack, bytes = core->bytes;
  35. mResource.reset(new Resource);
  36. mResource->backend = b;
  37. if (!mResource->copyBiasAlign(bias, biasSize)) {
  38. MNN_ERROR("Not Enough Memory\n");
  39. mValid = false;
  40. return;
  41. }
  42. MNN_ASSERT(mCommon->kernelX() == mCommon->kernelY());
  43. int threadNumber = ((CPUBackend *)backend())->threadNumber();
  44. auto kernelSize = mCommon->kernelY();
  45. WinogradGenerater generator(unit, kernelSize, 1, true);
  46. int ePack, hPack, lPack;
  47. core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
  48. int alpha = unit + kernelSize - 1;
  49. int alpha2 = alpha * alpha;
  50. mSourceTransform = core->chooseWinoSourceTransform(alpha, alpha);
  51. mDestTransform = core->chooseWinoDestTransform(alpha, unit);
  52. mSourceTransformPack = core->chooseWinoSourceTransformPack(alpha, alpha, ePack, lPack, pack);
  53. int srcCount = input->channel();
  54. int outputCount = output->channel();
  55. auto ic4 = UP_DIV(srcCount, pack);
  56. auto oc4 = UP_DIV(outputCount, pack);
  57. mTempBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, ePack, ic4 + oc4, pack * alpha2, bytes}));
  58. // mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, 2, alpha2, pack, bytes}));
  59. // mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, ePack * UP_DIV(srcCount, lPack) * lPack, bytes}));
  60. mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, (1 + ic4 * ePack), alpha2, pack, bytes})); // 1 means original small buffer of alpha2 * pack.
  61. mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>({threadNumber, alpha, ePack * UP_DIV(srcCount, pack) * pack, bytes}));
  62. mA = generator.A();
  63. mB = generator.B();
  64. // Transform Kernel
  65. auto G = generator.G();
  66. // replace Tensor::createDevice by Tensor::create and allocTransformWeight's alloc=true to avoid malloc by onAcquireBuffer
  67. std::shared_ptr<Tensor> sourceWeight(Tensor::create<float>(
  68. std::vector<int>{outputCount, srcCount, kernelSize, kernelSize}, (void *)originWeight, Tensor::CAFFE));
  69. auto tempWeight = generator.allocTransformWeight(sourceWeight.get(), lPack, hPack, true);
  70. auto shape = tempWeight->shape();
  71. shape.push_back(bytes);
  72. mResource->mWeight.reset(Tensor::createDevice<uint8_t>(shape));
  73. mValid = backend()->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC);
  74. if (!mValid) {
  75. return;
  76. }
  77. generator.transformWeight(tempWeight.get(), sourceWeight.get(), true);
  78. if (bytes != 4) {
  79. core->MNNFp32ToLowp(tempWeight->host<float>(), mResource->mWeight->host<int16_t>(), tempWeight->elementSize());
  80. } else {
  81. ::memcpy(mResource->mWeight->host<float>(), tempWeight->host<float>(), tempWeight->size());
  82. }
  83. mPostParameters = getPostParameters();
  84. }
  85. ConvolutionWinograd::~ConvolutionWinograd() {
  86. // Do nothing
  87. }
  88. bool ConvolutionWinograd::onClone(Backend* bn, const Op* op, Execution** dst) {
  89. if (!mValid) {
  90. return false;
  91. }
  92. if (nullptr == dst) {
  93. return true;
  94. }
  95. auto dstExe = new ConvolutionWinograd(mResource, op->main_as_Convolution2D()->common(), bn);
  96. dstExe->mA = mA;
  97. dstExe->mB = mB;
  98. dstExe->mTempBuffer.reset(Tensor::createDevice<uint8_t>(mTempBuffer->shape()));
  99. dstExe->mTransformMidBuffer.reset(Tensor::createDevice<uint8_t>(mTransformMidBuffer->shape()));
  100. dstExe->mGemmMidBuffer.reset(Tensor::createDevice<uint8_t>(mGemmMidBuffer->shape()));
  101. dstExe->mSourceTransform = mSourceTransform;
  102. dstExe->mDestTransform = mDestTransform;
  103. dstExe->mSourceTransformPack = mSourceTransformPack;
  104. dstExe->mPostParameters = mPostParameters;
  105. *dst = dstExe;
  106. return true;
  107. }
  108. ErrorCode ConvolutionWinograd::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
  109. auto core = static_cast<CPUBackend*>(backend())->functions();
  110. int pack = core->pack, bytes = core->bytes;
  111. auto input = inputs[0];
  112. auto output = outputs[0];
  113. auto dstUnit = mA->length(1); // m
  114. auto srcUnit = mA->length(0); // n
  115. int ePack, lPack, hPack;
  116. core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
  117. auto srcUnit2 = srcUnit * srcUnit;
  118. auto alphaXStride = srcUnit * ePack * pack;
  119. auto IC4alpha2Stride = srcUnit2 * ePack * pack;
  120. int ow = output->width();
  121. int oh = output->height();
  122. int iw = input->width();
  123. int ih = input->height();
  124. int ic_4 = UP_DIV(input->channel(), pack);
  125. int dc_4 = UP_DIV(output->channel(), pack);
  126. int batch = input->batch();
  127. // MNN_PRINT("%d, %d\n", srcUnit, dstUnit);
  128. int padY = mPadY;
  129. int padX = mPadX;
  130. auto wUnit = UP_DIV(ow, dstUnit); // ow / m
  131. auto hUnit = UP_DIV(oh, dstUnit); // oh / m
  132. auto totalCount = wUnit * hUnit * batch;
  133. // MNN_PRINT("ow=%d, oh=%d\n", ow, oh);
  134. int threadNumber = std::max(((CPUBackend *)backend())->threadNumber(), 1);
  135. int tileCount = UP_DIV(totalCount, ePack);
  136. int eRemain = totalCount % ePack;
  137. threadNumber = std::min(threadNumber, tileCount);
  138. std::vector<size_t> parameters(6);
  139. parameters[0] = eRemain * bytes;
  140. parameters[1] = input->channel();
  141. parameters[2] = output->channel();
  142. parameters[3] = ePack * pack * bytes;
  143. parameters[4] = 0;
  144. parameters[5] = 0;
  145. std::vector<size_t> parametersRemain = parameters;
  146. parametersRemain[3] = eRemain * pack * bytes;
  147. auto inputOrigin = input->host<uint8_t>();
  148. auto outputOrigin = output->host<uint8_t>();
  149. auto srcOrigin = inputOrigin;
  150. auto dstOrigin = outputOrigin;
  151. auto midBuffer0Bytes = srcUnit2 * pack * bytes;
  152. bool allow_x86_bf16_winograd = true;
  153. #ifdef MNN_USE_SSE
  154. allow_x86_bf16_winograd = bytes != 2; // only bf16 has length of 2 byte on x86. fp16 dosnot exist.
  155. #endif
  156. // using ElementType = int16_t;
  157. // MNN_PRINT("winograd: this:%p, n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kh:%d, kw:%d, totalCount:%d, srcUnit:%d, dstUnit:%d, ePack:%d, pack:%d, bytes:%d\n",
  158. // this, batch, ih, iw, input->channel(), oh, ow, output->channel(), mCommon->kernelX(), mCommon->kernelY(), totalCount, srcUnit, dstUnit, ePack, pack, bytes);
  159. // MNN_PRINT("origin data matrix:\n");
  160. // formatMatrix((const ElementType*)srcOrigin, {ic_4, batch*ih, iw, pack});
  161. auto weight = mResource->mWeight->host<uint8_t>();
  162. auto bias = mResource->mBias->host<uint8_t>();
  163. auto tFunction = [&](int tId) {
  164. auto _srcOrigin = mTempBuffer->host<uint8_t>() + tId * mTempBuffer->stride(0);
  165. auto gemmBuffer = (mGemmMidBuffer->host<uint8_t>() + tId * mGemmMidBuffer->stride(0));
  166. auto midBuffer0 = mTransformMidBuffer->host<uint8_t>() + tId * mTransformMidBuffer->stride(0);
  167. auto midBuffer1 = midBuffer0 + midBuffer0Bytes;
  168. for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) {
  169. int xIndex = (int)tIndex * ePack;
  170. int xReamin = totalCount - xIndex;
  171. int xC = xReamin > ePack ? ePack : xReamin;
  172. const bool fuseTransformPack = (xC * FULSE_THRESHHOLD_DENOMINATOR > FULSE_THRESHHOLD_NUMERATOR * ePack) && allow_x86_bf16_winograd;
  173. // const bool fuseTransformPack = false;
  174. // Timer timer;
  175. // uint64_t durationSourceTrans1 = 0;
  176. // uint64_t durationSourceTrans2 = 0;
  177. // uint64_t durationMul = 0;
  178. // uint64_t packATime = 0;
  179. // uint64_t durationDestTrans1 = 0;
  180. // uint64_t durationDestTrans2 = 0;
  181. /*Source Transform Begin*/
  182. #ifndef MNN_WINO_TRANFORM_TEST_CLOSE
  183. {
  184. int sourceZStep = iw * ih * batch * pack;
  185. int oyBegin = xIndex / wUnit;
  186. int oxBegin = xIndex % wUnit;
  187. int oyEnd = (xIndex + xC-1) / wUnit;
  188. int remain = xC;
  189. int destSOffset = 0;
  190. if (fuseTransformPack) {
  191. for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
  192. int hIndex = hbIndex % hUnit;
  193. int bIndex = hbIndex / hUnit;
  194. int step = std::min(wUnit - oxBegin, remain);
  195. int srcY = hIndex * dstUnit - padY;
  196. int ey = ALIMIN(srcY + srcUnit, ih) - srcY;
  197. int sy = ALIMAX(0, srcY) - srcY;
  198. for (int si=0; si<step; ++si) {
  199. auto wIndex = si + oxBegin;
  200. int srcX = wIndex * dstUnit - padX;
  201. int sx = ALIMAX(0, srcX) - srcX;
  202. int ex = ALIMIN(srcX + srcUnit, iw) - srcX;
  203. int count = pack * (ex - sx);
  204. auto srcStart = srcOrigin + (srcX + srcY * iw + bIndex * iw * ih) * pack * bytes;
  205. // MNN_PRINT("\nxIndex:%d, xC:%d, alphaXStride:%d, srcUnit:%d, destUnit:%d, hUnit:%d, wUnit:%d, srcY:%d, hStart:%d ,hEnd:%d, wStart:%d, wEnd:%d, i_oh:%d, i_ow:%d, srcOffset:%ld, destSOffset:%d\n",
  206. // xIndex, xC, alphaXStride, srcUnit, dstUnit, hUnit, wUnit, srcY, sy, ey, sx, ey, hIndex - oyBegin, si, (srcStart - srcOrigin)/bytes, (destSOffset)/bytes);
  207. // timer.reset();
  208. auto midBuffer1Offset = midBuffer1 + destSOffset;
  209. if (ex - sx == srcUnit && ey - sy == srcUnit) {
  210. for (int z = 0; z < ic_4; ++z) {
  211. auto srcZ = srcStart + z * sourceZStep * bytes;
  212. // Transform
  213. // MNN_PRINT("z:%d, srcOffset:%ld, destSOffset:%ld, \n", z, ((unsigned const char*)srcZ - srcOrigin)/bytes, ((unsigned const char*)midBuffer1Offset - midBuffer1)/bytes);
  214. // MNN_PRINT("winograd source sub matrix:\n");
  215. // formatMatrix((const float*)srcZ, {srcUnit, 4});
  216. for (int i = 0; i < srcUnit; ++i) { // i_Nh
  217. auto srcFloatPtr = (const float*)(srcZ + i * iw * pack * bytes);
  218. auto dstFloatPtr = (float*)(midBuffer1Offset + i * ePack * pack * bytes);
  219. mSourceTransform(srcFloatPtr, dstFloatPtr, pack, alphaXStride); // tranform srcUnit*4 elements in one time
  220. // MNN_PRINT("z:%d, 1 stage i_Nh:%d th srcOffset:%ld, destOffset:%ld, \n", z, i, ((unsigned const char*)srcFloatPtr - srcOrigin)/bytes, ((unsigned const char*)dstFloatPtr - midBuffer1)/bytes);
  221. // MNN_PRINT("winograd source sub matrix:\n");
  222. // formatMatrix(srcFloatPtr, {srcUnit, 4});
  223. }
  224. midBuffer1Offset += IC4alpha2Stride * bytes;
  225. }
  226. } else {
  227. for (int z = 0; z < ic_4; ++z) {
  228. // Extract
  229. auto srcZ = srcStart + z * sourceZStep * bytes;
  230. ::memset(midBuffer0, 0, midBuffer0Bytes);
  231. if (count > 0) {
  232. for (int yy = sy; yy < ey; ++yy) {
  233. auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
  234. auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
  235. ::memcpy(dst_yy, src_yy, count * bytes);
  236. }
  237. }
  238. // Transform
  239. for (int i = 0; i < srcUnit; ++i) {
  240. auto srcFloatPtr = (const float*)(midBuffer0 + i * srcUnit * pack * bytes);
  241. auto dstFloatPtr = (float*)(midBuffer1Offset + i * ePack * pack * bytes);
  242. mSourceTransform(srcFloatPtr, dstFloatPtr, pack, alphaXStride);
  243. }
  244. midBuffer1Offset += IC4alpha2Stride * bytes;
  245. }
  246. }
  247. // durationSourceTrans1 += timer.durationInUs();
  248. destSOffset += pack * bytes;
  249. }
  250. oxBegin = 0;
  251. remain -= step;
  252. }
  253. } else {
  254. int dstZStep = xC * pack; // hUnit*wUnit * 4
  255. int unitStep = ic_4 * xC * pack; // C/4 * hUnit*wUnit * 4
  256. for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
  257. int hIndex = hbIndex % hUnit;
  258. int bIndex = hbIndex / hUnit;
  259. int step = std::min(wUnit - oxBegin, remain);
  260. int srcY = hIndex * dstUnit - padY;
  261. int ey = ALIMIN(srcY + srcUnit, ih) - srcY; //h dim pack element length
  262. int sy = ALIMAX(0, srcY) - srcY; // first y element
  263. for (int si=0; si<step; ++si) {
  264. auto wIndex = si + oxBegin;
  265. int srcX = wIndex * dstUnit - padX;
  266. int sx = ALIMAX(0, srcX) - srcX;
  267. int ex = ALIMIN(srcX + srcUnit, iw) - srcX;
  268. int count = pack * (ex - sx);
  269. auto srcStart = srcOrigin + (srcX + srcY * iw + bIndex * iw * ih) * pack * bytes;
  270. // MNN_PRINT("\nxIndex:%d, xC:%d, alphaXStride:%d, srcUnit:%d, destUnit:%d, hUnit:%d, wUnit:%d, srcY:%d, hStart:%d ,hEnd:%d, wStart:%d, wEnd:%d, i_oh:%d, i_ow:%d, srcOffset:%ld, destSOffset:%d\n",
  271. // xIndex, xC, alphaXStride, srcUnit, dstUnit, hUnit, wUnit, srcY, sy, ey, sx, ey, hIndex - oyBegin, si, (srcStart - srcOrigin)/bytes, (destSOffset)/bytes);
  272. // timer.reset();
  273. auto dst_x = _srcOrigin + destSOffset;
  274. if (ex - sx == srcUnit && ey - sy == srcUnit) {
  275. for (int z = 0; z < ic_4; ++z) {
  276. auto srcZ = srcStart + z * sourceZStep * bytes;
  277. // Transform
  278. for (int i = 0; i < srcUnit; ++i) {
  279. auto srcFloatPtr = (const float*)(srcZ + i * iw * pack * bytes);
  280. auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
  281. // MNN_PRINT("z:%d, 1 stage i_Nh:%d th srcOffset:%ld, destOffset:%ld, \n", z, i, ((unsigned const char*)srcFloatPtr - srcOrigin)/bytes, ((unsigned const char*)dstFloatPtr - midBuffer1)/bytes);
  282. // MNN_PRINT("winograd source sub matrix:\n");
  283. // formatMatrix(srcFloatPtr, {srcUnit, pack});
  284. mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
  285. }
  286. auto dstZ = dst_x + z * dstZStep * bytes;
  287. for (int i = 0; i < srcUnit; ++i) {
  288. auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
  289. auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
  290. mSourceTransform(srcFloatPtr, dstFloatPtr, pack,
  291. unitStep * srcUnit);
  292. }
  293. }
  294. } else {
  295. for (int z = 0; z < ic_4; ++z) {
  296. // Extract
  297. auto srcZ = srcStart + z * sourceZStep * bytes;
  298. ::memset(midBuffer0, 0, mTransformMidBuffer->stride(1));
  299. if (count > 0) {
  300. for (int yy = sy; yy < ey; ++yy) {
  301. auto dst_yy = midBuffer0 + (yy * srcUnit + sx) * pack * bytes;
  302. auto src_yy = srcZ + (iw * yy + sx) * pack * bytes;
  303. ::memcpy(dst_yy, src_yy, count * bytes);
  304. }
  305. }
  306. // Transform
  307. for (int i = 0; i < srcUnit; ++i) {
  308. auto srcFloatPtr = (const float*)(midBuffer0 + i * srcUnit * pack * bytes);
  309. auto dstFloatPtr = (float*)(midBuffer1 + i * pack * bytes);
  310. mSourceTransform(srcFloatPtr, dstFloatPtr, pack, pack * srcUnit);
  311. }
  312. auto dstZ = dst_x + z * dstZStep * bytes;
  313. for (int i = 0; i < srcUnit; ++i) {
  314. auto srcFloatPtr = (const float*)(midBuffer1 + i * srcUnit * pack * bytes);
  315. auto dstFloatPtr = (float*)(dstZ + i * unitStep * bytes);
  316. mSourceTransform(srcFloatPtr, dstFloatPtr, pack, unitStep * srcUnit);
  317. }
  318. }
  319. }
  320. // durationSourceTrans1 += timer.durationInUs();
  321. destSOffset += pack * bytes;
  322. }
  323. oxBegin = 0;
  324. remain -= step;
  325. }
  326. }
  327. }
  328. #endif
  329. auto* _dstOrigin = _srcOrigin;
  330. if (fuseTransformPack) {
  331. _dstOrigin += ePack * srcUnit2 * ic_4 * pack * bytes;
  332. if (xC != ePack) {
  333. auto midTransformPtr = midBuffer1 + xC * pack * bytes;
  334. for (int i = 0; i < ic_4 * srcUnit2; ++i) {
  335. memset(midTransformPtr, 0, (ePack - xC) * pack * bytes);
  336. midTransformPtr += ePack * pack * bytes;
  337. }
  338. }
  339. // MNN_PRINT("winograd source matrix transform 1 D*B:\n");
  340. // formatMatrix((const ElementType*)midBuffer1, {ic_4, srcUnit, srcUnit, ePack, pack});
  341. for (int iNw = 0; iNw < srcUnit; ++iNw) { // i_Nw
  342. // timer.reset();
  343. auto midTransformPtr = midBuffer1 + iNw * alphaXStride * bytes;
  344. auto unitsGemmbuffer = gemmBuffer;
  345. for (int z = 0; z < ic_4; ++z) { // ic_4
  346. mSourceTransformPack((float*)midTransformPtr, (float*)unitsGemmbuffer, ePack * pack * ic_4);
  347. unitsGemmbuffer += ePack * pack * bytes;
  348. midTransformPtr += IC4alpha2Stride * bytes;
  349. }
  350. // durationSourceTrans2 += timer.durationInUs();
  351. // timer.reset();
  352. // MNN_PRINT("winograd source matrix transform 2 BT*D*B, iNw:%d\n", iNw);
  353. // formatMatrix((const ElementType*)gemmBuffer, {srcUnit, ic_4 * pack, ePack});
  354. // Previous tranform requires xC aligned with EPack, xC should be Epack;
  355. for (int iNh = 0; iNh < srcUnit; ++iNh) { // i_Nh, gemm
  356. auto unitsGemmbuffer = gemmBuffer + iNh * ic_4 * pack * ePack * bytes;
  357. auto _dstFloatPtr = (float*)(_dstOrigin + (iNh * srcUnit + iNw) * dc_4 * pack * ePack * bytes);
  358. auto _weightFloatPtr = (const float*)(weight + (iNh * srcUnit + iNw) * mResource->mWeight->stride(0));
  359. core->MNNPackedMatMul(_dstFloatPtr, (float*)unitsGemmbuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr);
  360. // MNN_PRINT("winograd MatMul result, iNh%d, iNw:%d\n", iNh, iNw);
  361. // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, ePack, pack});
  362. }
  363. // durationMul += timer.durationInUs();
  364. }
  365. } else {
  366. // MNN_PRINT("winograd source matrix after b*d*b:\n");
  367. // formatMatrix((const ElementType*)_srcOrigin, {srcUnit, srcUnit, ic_4, hUnit, wUnit, pack});
  368. /*Source Transform End*/
  369. // // Multi
  370. _dstOrigin += xC * srcUnit2 * ic_4 * pack * bytes;
  371. int32_t info[4];
  372. info[0] = 1;
  373. info[1] = xC;
  374. info[2] = xC;
  375. info[3] = 1;
  376. int32_t el[4];
  377. el[0] = xC;
  378. el[1] = parameters[1];
  379. el[2] = 0;
  380. el[3] = 0;
  381. if (xC == ePack) {
  382. for (int i = 0; i < srcUnit2; ++i) {
  383. // timer.reset();
  384. auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
  385. auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
  386. auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
  387. // MNN_PRINT("winograd i_n:%d, xC:%d, ePack:%d, before packA:\n", i, xC, ePack);
  388. // formatMatrix((const ElementType*)srcTemp, {ic_4, xC, pack});
  389. core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);
  390. // packATime += timer.durationInUs();
  391. // timer.reset();
  392. // MNN_PRINT("winograd i_n:%d, after packA:\n", i);
  393. // formatMatrix((const ElementType*)gemmBuffer, {1, ic_4 * pack, xC});
  394. core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr);
  395. // MNN_PRINT("winograd MatMul result, iNh:%d, iNw:%d\n", i/srcUnit, i % srcUnit);
  396. // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, xC, pack});
  397. // durationMul += timer.durationInUs();
  398. }
  399. } else {
  400. for (int i = 0; i < srcUnit2; ++i) {
  401. // timer.reset();
  402. auto srcTemp = (const float*)(_srcOrigin + i * ic_4 * pack * xC * bytes);
  403. auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes);
  404. auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0));
  405. // MNN_PRINT("winograd i_n:%d, xC:%d, ePack:%d, before packA:\n", i, xC, ePack);
  406. // formatMatrix((const ElementType*)srcTemp, {ic_4, xC, pack});
  407. core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el);
  408. // packATime += timer.durationInUs();
  409. // timer.reset();
  410. // MNN_PRINT("winograd i_n:%d, after packA:\n", i);
  411. // formatMatrix((const ElementType*)gemmBuffer, {1, ic_4 * pack, xC});
  412. core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr);
  413. // MNN_PRINT("winograd MatMul result, iNh:%d, iNw:%d\n", i/srcUnit, i % srcUnit);
  414. // formatMatrix((const ElementType*)_dstFloatPtr, { dc_4, xC, pack});
  415. // durationMul += timer.durationInUs();
  416. }
  417. }
  418. }
  419. #ifndef MNN_WINO_TRANFORM_TEST_CLOSE
  420. /* Dest Transform And Post Treat Begin */
  421. {
  422. int srcZStep = (fuseTransformPack ? ePack : xC) * pack;
  423. int unitStep = (fuseTransformPack ? ePack : xC) * dc_4 * pack;
  424. int dstZStep = ow * oh * pack * batch;
  425. int oyBegin = xIndex / wUnit;
  426. int oxBegin = xIndex % wUnit;
  427. int oyEnd = (xIndex + xC-1) / wUnit;
  428. int remain = xC;
  429. auto dstS = _dstOrigin;
  430. for (int hbIndex=oyBegin; hbIndex <= oyEnd; ++hbIndex) {
  431. int hIndex = hbIndex % hUnit;
  432. int bIndex = hbIndex / hUnit;
  433. int step = std::min(wUnit - oxBegin, remain);
  434. int dstY = hIndex * dstUnit;
  435. int ey = ALIMIN(dstY + dstUnit, oh) - dstY;
  436. for (int si=0; si<step; ++si) {
  437. auto wIndex = si + oxBegin;
  438. auto srcXi = dstS + pack * si * bytes;
  439. int dstX = wIndex * dstUnit;
  440. auto dstStart = dstOrigin + (dstX + dstY * ow + bIndex * ow * oh) * pack * bytes;
  441. int ex = ALIMIN(dstX + dstUnit, ow) - dstX;
  442. int count = ex * pack;
  443. if (ex == dstUnit) {
  444. for (int z = 0; z < dc_4; ++z) {
  445. auto dstZAddr = dstStart + z * dstZStep * bytes;
  446. auto srcZ = srcXi + z * srcZStep * bytes;
  447. // Transform
  448. for (int i = 0; i < srcUnit; ++i) {
  449. auto srcFloatPtr = (const float*)(srcZ + i * unitStep * bytes);
  450. auto dstFloatPtr = (float*)(midBuffer0 + i * dstUnit * pack * bytes);
  451. mDestTransform(srcFloatPtr, dstFloatPtr, srcUnit * unitStep, pack);
  452. }
  453. for (int i = 0; i < ey; ++i) {
  454. auto srcFloatPtr = (const float*)(midBuffer0 + i * pack * bytes);
  455. auto dstFloatPtr = (float*)(dstZAddr + i * pack * ow * bytes);
  456. mDestTransform(srcFloatPtr, dstFloatPtr, pack * dstUnit, pack);
  457. }
  458. }
  459. } else {
  460. for (int z = 0; z < dc_4; ++z) {
  461. auto dstZAddr = dstStart + z * dstZStep * bytes;
  462. auto srcZ = srcXi + z * srcZStep * bytes;
  463. // Transform
  464. for (int i = 0; i < srcUnit; ++i) {
  465. auto srcFloatPtr = (const float*)(srcZ + i * unitStep * bytes);
  466. auto dstFloatPtr = (float*)(midBuffer0 + i * dstUnit * pack * bytes);
  467. mDestTransform(srcFloatPtr, dstFloatPtr, srcUnit * unitStep, pack);
  468. }
  469. for (int i = 0; i < ey; ++i) {
  470. auto srcFloatPtr = (const float*)(midBuffer0 + i * pack * bytes);
  471. auto dstFloatPtr = (float*)(midBuffer1 + i * dstUnit * pack * bytes);
  472. mDestTransform(srcFloatPtr, dstFloatPtr, pack * dstUnit, pack);
  473. }
  474. for (int yy = 0; yy < ey; ++yy) {
  475. auto dstYAddr = dstZAddr + yy * pack * ow * bytes;
  476. auto srcYAddr = midBuffer1 + yy * pack * dstUnit * bytes;
  477. ::memcpy(dstYAddr, srcYAddr, count * bytes);
  478. }
  479. }
  480. }
  481. }
  482. oxBegin = 0;
  483. remain -= step;
  484. dstS += pack * step * bytes;
  485. }
  486. }
  487. #endif
  488. /*Dest Transform And Post Treat End*/
  489. // if (fuseTransformPack) {
  490. // MNN_PRINT(
  491. // "\n relayout fused:\n\tdurationSourceTrans1: %lu us \n\tdurationSourceTrans2: %lu us \n\tdurationMul: %lu us\n\ttotal: %lu us\n",
  492. // durationSourceTrans1, durationSourceTrans2, durationMul, durationSourceTrans1 + durationSourceTrans2 + durationMul);
  493. // } else {
  494. // MNN_PRINT(
  495. // "\n origin:\n\tdurationSourceTrans1+2: %lu us \n\t packA:%lu us \n\t durationMul:%lu us\n\ttotal: %lu us\n",
  496. // durationSourceTrans1, packATime, durationMul, durationSourceTrans1 + durationSourceTrans2 + durationMul + packATime);
  497. // }
  498. }
  499. };
  500. MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
  501. tFunction((int)tId);
  502. }
  503. MNN_CONCURRENCY_END();
  504. MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
  505. for (int dy=(int)tId; dy < dc_4; dy += threadNumber) {
  506. auto dataFloatPtr = (float*)(dstOrigin + ow * oh * batch * dy * pack * bytes);
  507. auto biasFloatPtr = (const float*)(bias + pack * dy * bytes);
  508. core->MNNAxByClampBroadcastUnit(dataFloatPtr, dataFloatPtr, biasFloatPtr, ow * oh * batch, 0, 0, 1, mPostParameters.data());
  509. }
  510. }
  511. MNN_CONCURRENCY_END();
  512. return NO_ERROR;
  513. }
  514. int ConvolutionWinograd::bestWinogradUnit(const Convolution2DCommon *common, const Tensor *inputTensor,
  515. const Tensor *outputTensor, int threadNumber, Backend* b) {
  516. auto core = static_cast<CPUBackend*>(b)->functions();
  517. int ow = outputTensor->width();
  518. int oh = outputTensor->height();
  519. int oc = outputTensor->channel();
  520. int ePack, hPack, lPack;
  521. core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
  522. int unit2 = UP_DIV(ow * oh, ePack * threadNumber);
  523. int maxUnit = (int)::sqrtf((float)unit2);
  524. maxUnit = std::min(maxUnit, CONVOLUTION_WINOGRAD_MAX_UNIT);
  525. maxUnit = std::max(maxUnit, CONVOLUTION_WINOGRAD_MIN_UNIT);
  526. int ic = inputTensor->channel();
  527. auto kernelSize = common->kernelY();
  528. int unit = 0;
  529. float maxRate = 0.0f;
  530. float originCost = (float)ow * oh * (float)ic * oc * kernelSize * kernelSize;
  531. std::set<int> supportSu{4, 6, 8};
  532. for (int u = CONVOLUTION_WINOGRAD_MIN_UNIT; u <= maxUnit; ++u) {
  533. auto sui = u + kernelSize - 1;
  534. auto su = (float)sui;
  535. if (supportSu.find(sui) == supportSu.end()) {
  536. continue;
  537. }
  538. if (nullptr == core->chooseWinoDestTransform((int)su, u)) {
  539. continue;
  540. }
  541. /*Let F(6,3) be choosed when it can speed up from F(2,3) than 0.6*/
  542. float penalty = (su * su) / (float)(kernelSize * kernelSize) * 0.12f;
  543. float winogradCost =
  544. (2 * su * su * ic + su * su * ic * oc + (su + u) * u * oc) * (UP_DIV(ow, u) * UP_DIV(oh, u));
  545. float reduceRate = originCost / winogradCost - penalty;
  546. // MNN_PRINT("ow=%d, oh=%d, %f, %f, winograd unit:%d\n", ow, oh, winogradCost, reduceRate, u);
  547. if (reduceRate > maxRate) {
  548. maxRate = reduceRate;
  549. unit = u;
  550. }
  551. }
  552. if (maxRate < 1.0f) {
  553. return 0;
  554. }
  555. return unit;
  556. }
  557. bool ConvolutionWinograd::canUseWinograd(const Convolution2DCommon *common) {
  558. if (common->kernelY() != common->kernelX() || common->kernelY() <= 1) {
  559. return false;
  560. }
  561. if (common->dilateX() != 1 || common->dilateY() != 1) {
  562. return false;
  563. }
  564. if (common->strideX() != 1 || common->strideY() != 1) {
  565. return false;
  566. }
  567. return true;
  568. }
  569. ErrorCode ConvolutionWinograd::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
  570. CPUConvolution::onResize(inputs, outputs);
  571. // FUNC_PRINT(mA->length(1));
  572. bool success = backend()->onAcquireBuffer(mTempBuffer.get(), Backend::DYNAMIC);
  573. success = success && backend()->onAcquireBuffer(mGemmMidBuffer.get(), Backend::DYNAMIC);
  574. success = success && (backend()->onAcquireBuffer(mTransformMidBuffer.get(), Backend::DYNAMIC));
  575. backend()->onReleaseBuffer(mTempBuffer.get(), Backend::DYNAMIC);
  576. backend()->onReleaseBuffer(mTransformMidBuffer.get(), Backend::DYNAMIC);
  577. backend()->onReleaseBuffer(mGemmMidBuffer.get(), Backend::DYNAMIC);
  578. if (!success) {
  579. return OUT_OF_MEMORY;
  580. }
  581. return NO_ERROR;
  582. }
  583. } // namespace MNN