//
//  ConvCutlassExecution.cu
//  MNN
//
//  Created by MNN on 2020/08/22.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include "ConvCutlassExecution.hpp"
#include "Raster.cuh"
#include "ConvBaseKernel.cuh"

//#define DEBUG

namespace MNN {
namespace CUDA {
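
// Resource holds the device-side, pre-packed filter and bias buffers. It is
// shared across cloned executions (see onClone below), so the weights are
// uploaded and reordered only once per op.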
ConvCutlassExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
    mBackend = bn;
    auto runtime = static_cast<CUDABackend*>(bn)->getCUDARuntime();

    auto conv   = op->main_as_Convolution2D();
    auto common = conv->common();

    // weight host->device
    const float* filterDataPtr = nullptr;
    int weightSize = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &filterDataPtr, &weightSize);

    auto oc = common->outputCount();
    int l  = weightSize / oc;
    int h  = oc;
    int lp = UP_DIV(l, 8) * 8;
    int hp = UP_DIV(h, 8) * 8;
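
    // The filter is treated as an [h, l] matrix: l = ic * kernelX * kernelY
    // (the GEMM reduction axis) and h = outputCount. Both are rounded up to
    // multiples of 8 so the packed operands meet the 8-element alignment the
    // CUTLASS GEMM paths below assume (the same padding is applied to elhPad
    // in onResize).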
    // Reorder weight
    {
        auto tempCacheBuffer = static_cast<CUDABackend*>(bn)->getStaticBufferPool()->alloc(weightSize * sizeof(float));
        float* cacheWeight = (float*)((uint8_t*)tempCacheBuffer.first + tempCacheBuffer.second);
        runtime->memcpy(cacheWeight, filterDataPtr, weightSize * sizeof(float), MNNMemcpyHostToDevice);
        if(static_cast<CUDABackend*>(bn)->getPrecision() == 1) {
            weightTensor.reset(Tensor::createDevice<int32_t>({lp * hp}));
        } else {
            weightTensor.reset(Tensor::createDevice<int16_t>({lp * hp}));
        }
        bn->onAcquireBuffer(weightTensor.get(), Backend::STATIC);
        mFilter = (void *)weightTensor.get()->buffer().device;
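
        // Precision levels (see the execution constructor below): 0 = fp16
        // compute with fp32 I/O, 1 = pure fp32, 2 = pure fp16, 3 = bf16.
        // callWeightFill packs the [h, l] filter into the padded [lp, hp]
        // device layout; as written, every level other than pure fp32 takes
        // the half-precision conversion path.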
        int precision = static_cast<CUDABackend*>(bn)->getPrecision();
        if(precision == 2) {
            precision = 0;
        }
        callWeightFill((const void *)cacheWeight, (void *)mFilter, l, h, lp, hp, static_cast<CUDABackend*>(bn)->getPrecision() == 1, runtime);
        static_cast<CUDABackend*>(bn)->getStaticBufferPool()->free(tempCacheBuffer);
    }

    // Copy Bias
    {
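        // fp16 path: stage the fp32 bias in a temporary device buffer, zero-pad
        // it to hp, then convert to half into mBias. fp32 path: memset + copy
        // the bias directly into mBias.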
        if(static_cast<CUDABackend*>(bn)->useFp16()) {
            int biasSize = conv->bias()->size();
            int hp = UP_DIV(biasSize, 8) * 8;
            auto tempBiasStorage = static_cast<CUDABackend*>(bn)->getStaticBufferPool()->alloc(hp * sizeof(float));
            auto biasTemp = (float*)((uint8_t*)tempBiasStorage.first + tempBiasStorage.second);
            runtime->memset(biasTemp, 0, hp * sizeof(int32_t));
            cuda_check(cudaMemcpy(biasTemp, conv->bias()->data(), conv->bias()->size() * sizeof(float), cudaMemcpyHostToDevice));

            biasTensor.reset(Tensor::createDevice<int16_t>({hp}));
            bn->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
            mBias = (void *)biasTensor.get()->buffer().device;
            callFloat2Half((const void*)biasTemp, (void*)mBias, hp, runtime);

            static_cast<CUDABackend*>(bn)->getStaticBufferPool()->free(tempBiasStorage);
        } else {
            int biasSize = conv->bias()->size();
            int hp = UP_DIV(biasSize, 8) * 8;
            biasTensor.reset(Tensor::createDevice<int32_t>({hp}));
            bn->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
            mBias = (void *)biasTensor.get()->buffer().device;
            runtime->memset(mBias, 0, hp * sizeof(int32_t));
            cuda_check(cudaMemcpy(mBias, conv->bias()->data(), conv->bias()->size() * sizeof(float), cudaMemcpyHostToDevice));
        }
    }
}

ConvCutlassExecution::Resource::~Resource() {
    // Do nothing
}

ConvCutlassExecution::ConvCutlassExecution(Backend* backend, const MNN::Op* op, std::shared_ptr<Resource> res) : CutlassConvCommonExecution(backend) {
    mOp = op;
    mResource = res;
    auto runtime = static_cast<CUDABackend*>(backend)->getCUDARuntime();
    mPrecisonLevel = static_cast<CUDABackend*>(backend)->getPrecision();
    mFp16Infer = (mPrecisonLevel == 2);
    mFp32Infer = (mPrecisonLevel == 1);
    mFp16Fp32MixInfer = (mPrecisonLevel == 0);
    mBf16Infer = (mPrecisonLevel == 3);
}

ConvCutlassExecution::~ConvCutlassExecution() {
}
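
// Clones share the Resource, so duplicated executions reuse the packed
// weight/bias buffers instead of re-uploading them.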
bool ConvCutlassExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (!mValid) {
        return false;
    }
    if (nullptr == dst) {
        return true;
    }
    auto dstExe = new ConvCutlassExecution(bn, op, mResource);
    *dst = dstExe;
    return true;
}

ErrorCode ConvCutlassExecution::onResize(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    auto input = inputs[0], output = outputs[0];
    const int UNIT = PACK_NUMBER;
    auto convCommon = mOp->main_as_Convolution2D()->common();
    auto pads = ConvolutionCommon::convolutionPadFull(input, output, mOp->main_as_Convolution2D()->common());
    int ic = input->channel();
    auto icDiv = UP_DIV(ic, UNIT);

    mIm2ColParamter.dilateX = convCommon->dilateX();
    mIm2ColParamter.dilateY = convCommon->dilateY();
    mIm2ColParamter.strideX = convCommon->strideX();
    mIm2ColParamter.strideY = convCommon->strideY();
    mIm2ColParamter.icDiv4  = icDiv;
    mIm2ColParamter.kernelX = convCommon->kernelX();
    mIm2ColParamter.kernelY = convCommon->kernelY();
    mIm2ColParamter.padX = std::get<0>(pads);
    mIm2ColParamter.padY = std::get<1>(pads);

    mIm2ColParamter.ih = input->height();
    mIm2ColParamter.iw = input->width();
    mIm2ColParamter.oh = output->height();
    mIm2ColParamter.ow = output->width();
    mIm2ColParamter.srcZStep = input->height() * input->width() * UNIT * input->batch();
    mIm2ColParamter.srcYStep = input->width() * UNIT;
    mIm2ColParamter.packCUnit = UNIT;

    mActivationType = convCommon->relu() ? 1 : convCommon->relu6() ? 2 : 0;

    //MNN_PRINT("conv size:%d-%d, %d-%d-%d, %d-%d-%d\n", mIm2ColParamter.kernelX, mIm2ColParamter.strideX, input->height(), input->width(), input->channel(), output->height(), output->width(), output->channel());
    int e = output->height() * output->width() * output->batch();
    int l = ic * mIm2ColParamter.kernelX * mIm2ColParamter.kernelY;
    int h = output->channel();
    mGemmInfo.elh[0] = e;
    mGemmInfo.elh[1] = l;
    mGemmInfo.elh[2] = h;
    mGemmInfo.elhPad[0] = UP_DIV(e, 8) * 8;
    mGemmInfo.elhPad[1] = UP_DIV(l, 8) * 8;
    mGemmInfo.elhPad[2] = UP_DIV(h, 8) * 8;
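
    // Implicit-GEMM view of the convolution: e = output pixels (oh * ow * batch),
    // l = reduction size (ic * kernelX * kernelY), h = output channels; each is
    // padded to a multiple of 8 to match the packed weight layout built in
    // Resource.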
    //MNN_PRINT("Activate:%d \n", mActivationType);
    //MNN_PRINT("Im2Col:%d-%d-%d temp size:%zu!!!\n\n",output->width(), ic, mIm2ColParamter.kernelX, (size_t)sizeof(__half) * mMatMulParam.elhPack[0] * mMatMulParam.elhPack[1] * MATMULPACK * MATMULPACK);

    // When the im2col buffer would exceed 2 GB, split e into blocks
    // (path currently disabled).
    if(0){ //(size_t)mGemmInfo.elh[0] * (size_t)mGemmInfo.elh[1] > 1024*1024*1024 && mIm2ColParamter.kernelX > 1 && mIm2ColParamter.kernelY > 1) {
        //printf("need im2col in block\n");
        mIsBlock = true;
        mBlockNum = 16;
        mGemmInfo.elh[0] = UP_DIV(mGemmInfo.elh[0], mBlockNum);
    }

    mIsConv1x1S1D1P0 = (mIm2ColParamter.kernelX == 1 && mIm2ColParamter.kernelY == 1 &&
                        mIm2ColParamter.strideX == 1 && mIm2ColParamter.strideY == 1 &&
                        mIm2ColParamter.dilateX == 1 && mIm2ColParamter.dilateY == 1 &&
                        mIm2ColParamter.padX == 0 && mIm2ColParamter.padY == 0);
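
    // A 1x1 kernel with unit stride/dilation and no padding is already a plain
    // GEMM over the input, so im2col is skipped when inference runs purely in
    // fp16 or fp32; the mixed path still needs a cast pass (see onExecute).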
    mNeedIm2Col = !(mIsConv1x1S1D1P0 && (mFp16Infer || mFp32Infer));
    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();
    if(mNeedIm2Col) {
        size_t im2colBytes = 2;
        // The im2col output is fp16 (2 bytes) in every mode except pure fp32
        // inference, where it stays fp32 (4 bytes).
        if(mFp32Infer) {
            im2colBytes = 4;
        }
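        // alloc + immediate free is the buffer-pool reservation idiom used in
        // MNN's CUDA backend: the pool keeps the region sized for this resize
        // while letting later allocations reuse it, and mIm2ColBuffer remains
        // valid for onExecute.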
        auto buffer = pool->alloc(im2colBytes * (size_t)mGemmInfo.elh[0] * (size_t)mGemmInfo.elhPad[1]);
        mIm2ColBuffer = (void*)((uint8_t*)buffer.first + buffer.second);
        pool->free(buffer);
    }

    mFilterAddr = mResource->mFilter;
    mBiasAddr   = mResource->mBias;
    mBackendPtr = mResource->mBackend;

    // Pick the CUTLASS kernel family by precision and GPU architecture
    if(mFp32Infer) {
        return callCutlassGemmCudaCoreFloat32(inputs, outputs);
    }

    mGpuComputeCap = runtime->compute_capability();
    //MNN_PRINT("Gpu smArch is sm_%d\n", mGpuComputeCap);
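    // sm < 70 has no tensor cores, so fall back to a CUDA-core fp16 GEMM;
    // sm 70/72 (Volta) uses the 8x8x4 tensor-op kernels; sm >= 75 (Turing and
    // newer) uses the regular tensor-op path.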
    if(mGpuComputeCap < 70) {
        return callCutlassGemmCudaCoreFloat16(inputs, outputs);
    } else if(mGpuComputeCap < 75) {
        return callCutlassGemmTensorCore884(inputs, outputs);
    }
    return callCutlassGemmTensorCore(inputs, outputs);
}

ErrorCode ConvCutlassExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    //MNN_PRINT("cuda convSingleInput onExecute in, inputsize:%d %d\n", (int)inputs.size(), workspace_size_);
    MNN_ASSERT(inputs.size() == 1);
    MNN_ASSERT(outputs.size() == 1);
    auto input = inputs[0];
    auto output = outputs[0];

    //printf("convcutlass:%p %p\n", input->deviceId(), output->deviceId());
    //MNN_PRINT("cutlass hw:%d-%d\n", input->height(), input->width());
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    const void *input_addr = (const void*)inputs[0]->deviceId();
    const void *filter_addr = mResource->mFilter;
    const void *bias_addr = mResource->mBias;
    auto bn = backend();
    void *output_addr = (void*)outputs[0]->deviceId();

    const int sw = mIm2ColParamter.strideX;
    const int sh = mIm2ColParamter.strideY;
    const int dw = mIm2ColParamter.dilateX;
    const int dh = mIm2ColParamter.dilateY;
    const int pw = mIm2ColParamter.padX;
    const int ph = mIm2ColParamter.padY;
    const int icDiv4 = mIm2ColParamter.icDiv4;
    const int iw = mIm2ColParamter.iw;
    const int ih = mIm2ColParamter.ih;

    //printf("%d-%d-%d-%d-%d, %d-%d\n", cpuIm2Col->icDiv4, cpuIm2Col->ih, cpuIm2Col->iw, cpuIm2Col->oh, cpuIm2Col->ow, eAlign, lAlign);
    // Im2col in Block
    for(int block_idx = 0; block_idx < mBlockNum; block_idx++) {
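        // With the blocked path disabled in onResize, mBlockNum stays at its
        // default and this loop runs once. 1x1/s1/d1/p0 with mixed precision
        // only needs an fp32 -> fp16 cast of the input; every other shape that
        // requires im2col goes through the packed im2col kernel.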
        if(mIsConv1x1S1D1P0 && mFp16Fp32MixInfer) {
            size_t maxCount = mGemmInfo.elh[0] * mGemmInfo.elhPad[1];
            callFloat2Half(input_addr, mIm2ColBuffer, maxCount, runtime);
        } else if (mNeedIm2Col) {
            callIm2ColPack((const void *)input_addr, (void *)mIm2ColBuffer, &mIm2ColParamter, mGemmInfo.elh[0], mGemmInfo.elh[1],
                           mGemmInfo.elhPad[0], mGemmInfo.elhPad[1], mPrecisonLevel, runtime);
        }
    }

    // Run cutlass gemm forward
    return runCutlassGemmFunc();
}

}// namespace CUDA
}// namespace MNN