//
//  OneDNNConvInt8.cpp
//
//
#ifdef MNN_USE_ONEDNN
#include "backend/cpu/OneDNNConvInt8.hpp"
#include "core/ConvolutionCommon.hpp"
using namespace dnnl;
using tag = memory::format_tag;
using dt  = memory::data_type;
namespace MNN {
OneDNNConvInt8::~OneDNNConvInt8() {
    // Do nothing
}
Execution* OneDNNConvInt8::create(Backend* backend, const MNN::Convolution2D* convParam,
                                  const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    std::shared_ptr<OneDNNConvInt8::Resource> resource(new OneDNNConvInt8::Resource);
    resource->backend = backend;
    const auto convCommon = convParam->common();
    const auto kw         = convCommon->kernelX();
    const auto kh         = convCommon->kernelY();
    const auto ic         = convCommon->inputCount();
    const auto oc         = convCommon->outputCount();
    const auto strideX    = convCommon->strideX();
    const auto strideY    = convCommon->strideY();
    std::vector<float> scale(oc);
    for (size_t i = 0; i < scale.size(); i++) {
        scale[i] = convParam->symmetricQuan()->scale()->data()[i];
    }
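    // oneDNN output-scale mask: bit 1 set (mask = 1 << 1 = 2) selects dimension 1
    // of the destination, i.e. one requantization scale per output channel.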
    const int conv_mask = 2;
    resource->conv_attr.set_output_scales(conv_mask, scale);
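    // Note: relu6 is lowered to a plain (unbounded) eltwise_relu post-op here
    // (alpha = beta = 0); the upper clamp at 6 is not expressed to oneDNN.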
    if (convCommon->relu() || convCommon->relu6()) {
        post_ops ops;
        ops.append_eltwise(1.0f, algorithm::eltwise_relu, 0.0f, 0.0f);
        resource->conv_attr.set_post_ops(ops);
    }
    auto eng      = engine(engine::kind::cpu, 0);
    resource->eng = eng;
    auto stm      = stream(eng);
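    // The shapes below are dummies chosen to yield a 1 x oc x 2 x 2 output with no
    // padding: the primitive descriptor built from them is used only to query
    // oneDNN's preferred weight layout, so the weights can be reordered once at
    // create time. Real input/output shapes are supplied later in onResize.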
    memory::dims conv_weights_tz = {oc, ic, kh, kw};
    memory::dims conv_bias_tz    = {oc};
    // oneDNN takes spatial parameters in (h, w) order for NCHW tensors.
    memory::dims conv_strides    = {strideY, strideX};
    memory::dims conv_src_tz     = {1, ic, strideY + (kh - 1) * convCommon->dilateY() + 1, strideX + (kw - 1) * convCommon->dilateX() + 1};
    memory::dims conv_dst_tz     = {1, oc, 2, 2};
    memory::dims conv_padding    = {0, 0};
    auto user_weights_md = memory::desc({conv_weights_tz}, dt::s8, tag::oihw);
    auto conv_src_md     = memory::desc({conv_src_tz}, dt::s8, tag::any);
    auto conv_weights_md = memory::desc({conv_weights_tz}, dt::s8, tag::any);
    auto conv_bias_md    = memory::desc({conv_bias_tz}, dt::s32, tag::a);
    auto conv_dst_md     = memory::desc({conv_dst_tz}, dt::s8, tag::any);
    auto conv_desc = convolution_forward::desc(prop_kind::forward_inference,
                                               algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md,
                                               conv_dst_md, conv_strides, conv_padding, conv_padding);
    auto conv_pd = convolution_forward::primitive_desc(conv_desc, resource->conv_attr, eng);
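    // Allocate persistent (STATIC) storage: the weights in whatever blocked layout
    // the primitive descriptor picked, plus the s32 bias.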
    auto weightSrc = convParam->symmetricQuan()->weight()->data();
    resource->mWeight.reset(Tensor::createDevice<int8_t>({(int)conv_pd.weights_desc().get_size()}));
    resource->mBias.reset(Tensor::createDevice<int32_t>({(int)convParam->symmetricQuan()->bias()->size()}));
    auto res = backend->onAcquireBuffer(resource->mWeight.get(), Backend::STATIC);
    res = res && backend->onAcquireBuffer(resource->mBias.get(), Backend::STATIC);
    if (!res) {
        return nullptr;
    }
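    // If the model carries compressed/quantized weights (quanParameter), decode
    // them and use the decoded buffer instead of the raw symmetricQuan weights.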
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    if (convParam->quanParameter() != nullptr) {
        quanCommon = ConvolutionCommon::load(convParam, backend, false);
        weightSrc  = quanCommon->weight.get();
    }
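    // Reorder the weights once, from plain oihw into the layout the convolution
    // primitive prefers; onExecute then uses the reordered copy directly.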
    auto user_weights = memory(user_weights_md, eng, (int8_t*)weightSrc);
    auto conv_weights = memory(conv_pd.weights_desc(), eng, resource->mWeight->host<int8_t>());
    auto r_pd = reorder::primitive_desc(user_weights, conv_weights);
    reorder(r_pd).execute(stm, user_weights, conv_weights);
    ::memcpy(resource->mBias->host<int32_t>(), convParam->symmetricQuan()->bias()->data(),
             convParam->symmetricQuan()->bias()->size() * sizeof(int32_t));
    resource->conv_bias    = memory(conv_bias_md, eng, resource->mBias->host<int32_t>());
    resource->conv_weights = conv_weights;
    return new OneDNNConvInt8(resource, convCommon, backend);
}
OneDNNConvInt8::OneDNNConvInt8(std::shared_ptr<OneDNNConvInt8::Resource> resource, const MNN::Convolution2DCommon* common, Backend* bn)
    : CPUConvolution(common, bn) {
    mResource = resource;
    stm       = stream(mResource->eng);
}
bool OneDNNConvInt8::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (nullptr == dst) {
        return true;
    }
    auto dstExe = new OneDNNConvInt8(mResource, op->main_as_Convolution2D()->common(), bn);
    *dst = dstExe;
    return true;
}
ErrorCode OneDNNConvInt8::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    const auto convCommon = mCommon;
    const auto kw      = convCommon->kernelX();
    const auto kh      = convCommon->kernelY();
    const auto ic      = convCommon->inputCount();
    const auto oc      = convCommon->outputCount();
    const auto strideX = convCommon->strideX();
    const auto strideY = convCommon->strideY();
    const auto ih      = inputs[0]->height();
    const auto iw      = inputs[0]->width();
    const auto oh      = outputs[0]->height();
    const auto ow      = outputs[0]->width();
    auto pads = ConvolutionCommon::convolutionPadFull(inputs[0], outputs[0], mCommon);
    memory::dims conv_src_tz     = {inputs[0]->batch(), ic, ih, iw};
    memory::dims conv_weights_tz = {oc, ic, kh, kw};
    memory::dims conv_bias_tz    = {oc};
    memory::dims conv_dst_tz     = {outputs[0]->batch(), oc, oh, ow};
    // (h, w) order for oneDNN.
    memory::dims conv_strides = {strideY, strideX};
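    // tag::nChw4c matches MNN's NC4HW4 tensor layout, so the input/output tensors
    // can be wrapped as oneDNN memories without copying; tag::any again lets the
    // primitive pick its preferred internal layouts.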
    auto user_src_md = memory::desc({conv_src_tz}, dt::s8, tag::nChw4c);
    auto user_dst_md = memory::desc({conv_dst_tz}, dt::s8, tag::nChw4c);
    auto conv_src_md = memory::desc({conv_src_tz}, dt::s8, tag::any);
    auto conv_dst_md = memory::desc({conv_dst_tz}, dt::s8, tag::any);
    user_src = memory(user_src_md, mResource->eng, inputs[0]->host<int8_t>());
    user_dst = memory(user_dst_md, mResource->eng, outputs[0]->host<int8_t>());
    mSrcTemp = nullptr;
    mDstTemp = nullptr;
    // Reuse the weight and bias descriptors fixed at create time, so the primitive
    // matches the already-reordered weights.
    auto conv_desc = convolution_forward::desc(prop_kind::forward_inference,
                                               algorithm::convolution_auto, conv_src_md, mResource->conv_weights.get_desc(), mResource->conv_bias.get_desc(),
                                               conv_dst_md, conv_strides,
                                               {std::get<1>(pads), std::get<0>(pads)},   // padding_l = (top, left)
                                               {std::get<3>(pads), std::get<2>(pads)});  // padding_r = (bottom, right)
    auto conv_pd = convolution_forward::primitive_desc(conv_desc, mResource->conv_attr, mResource->eng);
    conv = convolution_forward(conv_pd);
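    // If the primitive chose layouts other than nChw4c for src/dst, allocate
    // scratch buffers and reorder through them at execute time.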
    if (conv_pd.src_desc() != user_src.get_desc()) {
        auto needSize = conv_pd.src_desc().get_size();
        mSrcTemp.reset(Tensor::createDevice<int8_t>({(int)needSize}));
        auto res = backend()->onAcquireBuffer(mSrcTemp.get(), Backend::DYNAMIC);
        if (!res) {
            return OUT_OF_MEMORY;
        }
        conv_src = memory(conv_pd.src_desc(), mResource->eng, mSrcTemp->host<int8_t>());
    }
    if (conv_pd.dst_desc() != user_dst.get_desc()) {
        auto needSize = conv_pd.dst_desc().get_size();
        mDstTemp.reset(Tensor::createDevice<int8_t>({(int)needSize}));
        auto res = backend()->onAcquireBuffer(mDstTemp.get(), Backend::DYNAMIC);
        if (!res) {
            return OUT_OF_MEMORY;
        }
        conv_dst = memory(conv_pd.dst_desc(), mResource->eng, mDstTemp->host<int8_t>());
    }
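    // Releasing the DYNAMIC buffers at the end of resize follows MNN's memory
    // idiom: the allocator may hand the space to later ops in the plan, while the
    // pointers remain valid for this op's onExecute.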
    if (nullptr != mSrcTemp) {
        backend()->onReleaseBuffer(mSrcTemp.get(), Backend::DYNAMIC);
    }
    if (nullptr != mDstTemp) {
        backend()->onReleaseBuffer(mDstTemp.get(), Backend::DYNAMIC);
    }
    return NO_ERROR;
}
ErrorCode OneDNNConvInt8::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
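    // Execution pipeline: optionally reorder the nChw4c input into the primitive's
    // layout, run the int8 convolution, then optionally reorder the result back.
    // The user_src/user_dst memories were already bound to the tensors in onResize.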
    memory conv_src_temp = user_src;
    if (nullptr != mSrcTemp) {
        auto r_pd = reorder::primitive_desc(user_src, conv_src);
        reorder(r_pd).execute(stm, user_src, conv_src);
        conv_src_temp = conv_src;
    }
    memory conv_dst_temp = user_dst;
    if (nullptr != mDstTemp) {
        conv_dst_temp = conv_dst;
    }
    conv.execute(stm, {{DNNL_ARG_SRC, conv_src_temp},
                       {DNNL_ARG_WEIGHTS, mResource->conv_weights},
                       {DNNL_ARG_BIAS, mResource->conv_bias},
                       {DNNL_ARG_DST, conv_dst_temp}});
    if (nullptr != mDstTemp) {
        auto r_pd = reorder::primitive_desc(conv_dst, user_dst);
        reorder(r_pd).execute(stm, conv_dst, user_dst);
    }
    return NO_ERROR;
}
} // namespace MNN
#endif // MNN_USE_ONEDNN