2 Commits daa62c77c1 ... 9a085992ea

Author SHA1 Message Date
  王召德 9a085992ea Merge pull request #3725 from yanzhang-dev/features/imatmul-fp16 3 weeks ago
  yanzhang 8e7a63d622 Add imatmul fp16 support for DenseConv 4 weeks ago

+ 4 - 4
source/backend/cpu/compute/ConvolutionFloatFactory.cpp

@@ -97,12 +97,12 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 #else
     if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
 #ifdef MNN_KLEIDIAI_ENABLED
-	if (MNNGetCPUInfo()->sme2 && !weightQuantInfo && cpuBackend->functions()->bytes == 4) {
+	if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
 	    return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
 	}
-#else
-        return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
 #endif
+
+        return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
     }
 #endif
 
@@ -122,7 +122,7 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 #endif
 
 #ifdef MNN_KLEIDIAI_ENABLED
-    if (MNNGetCPUInfo()->sme2 && !weightQuantInfo && cpuBackend->functions()->bytes == 4) {
+    if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
 	return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
     }
 #endif
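
Before this commit, the factory routed to KleidiAIDenseConvolution only when activations were 4 bytes wide (fp32); the fp16 backend always fell back to DenseConvolutionTiledExecutor. The rewrite drops the byte-width guard, and the restructured #ifdef in the Memory_Low branch also makes the tiled executor the fallback regardless of whether KleidiAI is compiled in, where previously it was created only in the #else (KleidiAI disabled) configuration. A minimal, self-contained sketch of the resulting dispatch (the enum and function names here are illustrative, not MNN's):

    #include <cstdio>

    // Illustrative stand-ins for the two executors the factory can create.
    enum class Exec { KleidiAIDenseConvolution, DenseConvolutionTiledExecutor };

    // Sketch of the post-commit dispatch: no more bytes == 4 guard, so both
    // fp32 (bytes == 4) and fp16 (bytes == 2) backends reach the SME2 path.
    Exec pickExecutor(bool hasSme2, bool hasWeightQuantInfo) {
        if (hasSme2 && !hasWeightQuantInfo) {
            return Exec::KleidiAIDenseConvolution;
        }
        return Exec::DenseConvolutionTiledExecutor; // unconditional fallback
    }

    int main() {
        // fp16 on an SME2 machine now selects the KleidiAI executor.
        std::printf("%d\n", pickExecutor(true, false) == Exec::KleidiAIDenseConvolution);
        return 0;
    }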

+ 50 - 8
source/backend/cpu/compute/KleidiAIDenseConvolution.cpp

@@ -9,8 +9,11 @@
 #include "backend/cpu/CPUTensorConvert.hpp"
 #include "core/Macro.h"
 #include "core/TensorUtils.hpp"
+#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
 #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
+#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
 #include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
+#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
 #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
 
 namespace MNN {
@@ -26,8 +29,11 @@ static void initWeight(const T* weight, const T* bias, T* cache, T* output, cons
     if (bytes == 4) {
         kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(outputCount, kh * kw, srcCount, outputCount * sizeof(T),
                                                             cache, bias, output);
+    } else if (bytes == 2) {
+        kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(outputCount, kh * kw, srcCount, outputCount * sizeof(T),
+                                                            cache, bias, output);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
 }
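
The kxn RHS pack routines consume the weights as a K×N matrix with K = kh·kw·srcCount and N = outputCount, i.e. HWIO order, which `cache` is expected to hold; the fp16 constructor path below produces it explicitly with ConvertOIHWToHWIO. A minimal sketch of that relayout, with the indexing assumed from the helper's name rather than taken from MNN's implementation:

    #include <cstddef>

    // Assumed semantics of ConvertOIHWToHWIO: dst[h][w][i][o] = src[o][i][h][w].
    // K×N view for the imatmul pack: row k = (h*W + w)*I + i, column n = o.
    template <typename T>
    void convertOIHWToHWIO(T* dst, const T* src, int O, int I, int H, int W) {
        for (int o = 0; o < O; ++o)
            for (int i = 0; i < I; ++i)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        dst[(((size_t)h * W + w) * I + i) * O + o] =
                            src[(((size_t)o * I + i) * H + h) * W + w];
    }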
@@ -49,8 +55,11 @@ KleidiAIDenseConvolution::KleidiAIDenseConvolution(const Convolution2DCommon* co
     if (core->bytes == 4) {
         kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
             outputCount, common->kernelY() * common->kernelX(), srcCount);
+    } else if (core->bytes == 2) {
+        kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, common->kernelY() * common->kernelX(), srcCount);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     mResource->mWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
@@ -76,8 +85,17 @@ KleidiAIDenseConvolution::KleidiAIDenseConvolution(const Convolution2DCommon* co
     if (core->bytes == 4) {
         MNN::initWeight(originWeight, bias, cache->host<float>(), mResource->mWeight->host<float>(), oihwShape,
                         core->bytes);
+    } else if (core->bytes == 2) {
+        for (int i = 0; i < outputCount; i++) {
+            mResource->mBias->host<__fp16>()[i] = (__fp16)(bias[i]);
+        }
+        ConvertOIHWToHWIO(cache->host<__fp16>(), originWeight,
+                          {outputCount, srcCount, common->kernelY(), common->kernelX()});
+        kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, common->kernelY() * common->kernelX(), srcCount, outputCount * sizeof(__fp16),
+            cache->host<__fp16>(), mResource->mBias->host<__fp16>(), mResource->mWeight->host<__fp16>());
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
 
@@ -135,8 +153,11 @@ ErrorCode KleidiAIDenseConvolutionMultiInput::onExecute(const std::vector<Tensor
     if (function->bytes == 4) {
         initWeight(source, mInputs[2]->host<float>(), cache, mTempWeight->host<float>(), inputs[1]->shape(),
                    function->bytes);
+    } else if (function->bytes == 2) {
+        initWeight(reinterpret_cast<const __fp16*>(source), mInputs[2]->host<__fp16>(),
+                   reinterpret_cast<__fp16*>(cache), mTempWeight->host<__fp16>(), inputs[1]->shape(), function->bytes);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     return mProxy->onExecute(mInputs, outputs);
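
The fp16 branches cast raw pointers to `__fp16`, the Arm C language extension half-precision type; since it is two bytes wide, `function->bytes == 2` is what identifies the fp16 backend. A one-line sanity check, AArch64 only:

    // Compiles only on Arm targets that provide __fp16.
    static_assert(sizeof(__fp16) == 2, "__fp16 must be two bytes");
    int main() { __fp16 h = (__fp16)1.5f; return (float)h == 1.5f ? 0 : 1; }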
@@ -150,8 +171,12 @@ ErrorCode KleidiAIDenseConvolutionMultiInput::onResize(const std::vector<Tensor*
         int kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
             outputCount, inputs[1]->stride(1), depth);
         mTempWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
+    } else if (function->bytes == 2) {
+        int kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, inputs[1]->stride(1), depth);
+        mTempWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     mTempWeightCache.reset(Tensor::createDevice<float>(
@@ -206,8 +231,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
     if (core->bytes == 4) {
         mTempBufferTranspose.buffer().dim[0].extent =
             kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(outputNhwSize, kernelSize, ic);
+    } else if (core->bytes == 2) {
+        mTempBufferTranspose.buffer().dim[0].extent =
+            kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme(outputNhwSize, kernelSize, ic);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     TensorUtils::setLinearLayout(&mTempBufferTranspose);
@@ -289,8 +317,16 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
             kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(outputNhwSize, kernelSize, ic, table.data.data(), 0,
                                                         mPadBuffer.host<uint8_t>(),
                                                         mTempBufferTranspose.host<uint8_t>());
+        } else if (bytes == 2) {
+            int blockSize = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme();
+            ::memset(mPadBuffer.host<__fp16>(), 0, params.inputChannel * sizeof(__fp16));
+            auto table = IndirectionTable<__fp16>(mInputNHWC.shape(), params, mInputNHWC.host<__fp16>(),
+                                                  mPadBuffer.host<__fp16>(), blockSize);
+            kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme(outputNhwSize, kernelSize, ic, table.data.data(), 0,
+                                                        mPadBuffer.host<uint8_t>(),
+                                                        mTempBufferTranspose.host<uint8_t>());
         } else {
-            MNN_ERROR("Not fp32, should not be called here\n");
+            MNN_ERROR("Not fp32 and fp16, should not be called here\n");
             abort();
         }
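
The LHS side of the imatmul is indirect: instead of materializing an im2col buffer, the pack routine walks a table holding one pointer per (output pixel, kernel tap), pointing either into the NHWC input or at the zeroed `mPadBuffer` row when the tap falls in padding; the fp16 branch only swaps the element type and the x16 pack entry point. A rough, self-contained sketch of how such a table is assumed to be built (IndirectionTable's real logic lives in MNN; the stride/pad handling here is illustrative, dilation 1, batch 1):

    #include <vector>

    // One pointer per (output pixel, kernel tap): either the input pixel's
    // channel vector in NHWC, or a shared zero row for out-of-bounds taps.
    template <typename T>
    std::vector<const T*> buildIndirection(const T* nhwc, const T* padRow,
                                           int ih, int iw, int ic, int oh, int ow,
                                           int kh, int kw, int sy, int sx,
                                           int py, int px) {
        std::vector<const T*> table;
        table.reserve((size_t)oh * ow * kh * kw);
        for (int oy = 0; oy < oh; ++oy)
            for (int ox = 0; ox < ow; ++ox)
                for (int ky = 0; ky < kh; ++ky)
                    for (int kx = 0; kx < kw; ++kx) {
                        int iy = oy * sy + ky - py;
                        int ix = ox * sx + kx - px;
                        bool in = iy >= 0 && iy < ih && ix >= 0 && ix < iw;
                        table.push_back(in ? nhwc + ((size_t)iy * iw + ix) * ic
                                           : padRow);
                    }
        return table;
    }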
 
@@ -300,8 +336,14 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
                 outputNhwSize, outputChannel, kernelSize, ic, mTempBufferTranspose.host<uint8_t>(),
                 weight->host<uint8_t>(), mOutputNHWC.host<uint8_t>(), outputChannel * sizeof(float), postParameters[2],
                 postParameters[3]);
+        } else if (bytes == 2) {
+            float max = postParameters[3] > 65504.f ? 65504.f : postParameters[3];
+            kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(
+                outputNhwSize, outputChannel, kernelSize, ic, mTempBufferTranspose.host<uint8_t>(),
+                weight->host<uint8_t>(), mOutputNHWC.host<uint8_t>(), outputChannel * sizeof(__fp16), postParameters[2],
+                max);
         } else {
-            MNN_ERROR("Not fp32, should not be called here\n");
+            MNN_ERROR("Not fp32 and fp16, should not be called here\n");
             abort();
         }
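
One genuinely new detail in the fp16 epilogue: the upper clamp bound is capped at 65504.f, the largest finite IEEE fp16 value ((2 − 2⁻¹⁰)·2¹⁵). The fp32 path passes postParameters[3] through unchanged, but when there is no ReLU6-style clamp that bound is typically a very large fp32 value with no fp16 representation, so it must be saturated before reaching the f16 kernel. A standalone check:

    #include <cstdio>

    int main() {
        const float fp16Max = 65504.f;      // (2 - 2^-10) * 2^15, max finite half
        float bound = 3.402823466e38f;      // e.g. an fp32 "no clamp" bound
        float used = bound > fp16Max ? fp16Max : bound; // what the diff computes
        std::printf("clamp max handed to the f16 kernel: %g\n", used); // 65504
        return 0;
    }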