
[MNN:Sync] Sync Internal 2.7.1

zhaode.wzd 1 year ago
parent
commit
bdf15442f4
100 changed files with 1864 additions and 963 deletions
  1. 1 1
      docs/contribute/backend.md
  2. 101 28
      docs/inference/module.md
  3. 15 4
      express/Expr.cpp
  4. 1 0
      express/Utils.cpp
  5. 1 1
      include/MNN/MNNDefine.h
  6. 1 0
      include/MNN/Tensor.hpp
  7. 18 7
      package_scripts/win/build_lib_release.ps1
  8. 1 1
      pymnn/pip_package/build_deps.py
  9. 1 1
      pymnn/src/MNN.cc
  10. 2 2
      source/backend/arm82/Arm82Backend.cpp
  11. 3 3
      source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S
  12. 4 4
      source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S
  13. 10 5
      source/backend/coreml/backend/CoreMLBackend.cpp
  14. 3 2
      source/backend/coreml/backend/CoreMLBackend.hpp
  15. 1 1
      source/backend/coreml/execution/CoreMLConvolution.cpp
  16. 8 8
      source/backend/cpu/CPUBackend.cpp
  17. 3 3
      source/backend/cpu/CPUBackend.hpp
  18. 24 0
      source/backend/cpu/CPUCast.cpp
  19. 1 1
      source/backend/cpu/CPUConvolution.cpp
  20. 1 1
      source/backend/cpu/CPUConvolutionDepthwise.cpp
  21. 3 3
      source/backend/cpu/CPUDeconvolution.cpp
  22. 1 1
      source/backend/cpu/CPUDeconvolutionDepthwise.cpp
  23. 2 2
      source/backend/cpu/CPURaster.cpp
  24. 4 2
      source/backend/cpu/CPUTensorConvert.cpp
  25. 1 1
      source/backend/cpu/OneDNNConvInt8.cpp
  26. 6 6
      source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S
  27. 3 3
      source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S
  28. 2 2
      source/backend/cpu/bf16/BF16Backend.cpp
  29. 1 1
      source/backend/cpu/compute/ConvolutionFloatFactory.cpp
  30. 2 2
      source/backend/cpu/compute/DeconvolutionWithStride.cpp
  31. 1 1
      source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp
  32. 0 2
      source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp
  33. 91 67
      source/backend/cpu/x86_x64/avx/GemmFunction.hpp
  34. 67 49
      source/backend/cpu/x86_x64/sse/GemmFunction.hpp
  35. 1 1
      source/backend/cuda/core/CUDABackend.cpp
  36. 2 1
      source/backend/cuda/core/CUDABackend.hpp
  37. 6 2
      source/backend/cuda/core/runtime/CUDARuntime.hpp
  38. 1 1
      source/backend/cuda/execution/ConvCutlassExecution.cu
  39. 1 1
      source/backend/cuda/execution/ConvDepthWiseExecution.cu
  40. 1 1
      source/backend/cuda/execution/ConvWinogradExecution.cu
  41. 1 1
      source/backend/cuda/execution/DeconvSingleInputExecution.cu
  42. 36 22
      source/backend/cuda/execution/TopKV2Execution.cu
  43. 376 229
      source/backend/cuda/execution/Transpose.cu
  44. 1 1
      source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu
  45. 1 1
      source/backend/cuda/execution/int8/ConvInt8CutlassExecution.cu
  46. 10 10
      source/backend/hiai/backend/NPUBackend.cpp
  47. 3 2
      source/backend/hiai/backend/NPUBackend.hpp
  48. 1 1
      source/backend/hiai/execution/NPUConvolution.cpp
  49. 1 1
      source/backend/hiai/execution/NPUConvolutionDepthwise.cpp
  50. 2 1
      source/backend/metal/MetalBackend.hpp
  51. 2 1
      source/backend/metal/MetalBackend.mm
  52. 1 1
      source/backend/metal/MetalConvolutionCommon.mm
  53. 1 1
      source/backend/metal/MetalDeconvolution.mm
  54. 3 2
      source/backend/nnapi/backend/NNAPIBackend.cpp
  55. 3 2
      source/backend/nnapi/backend/NNAPIBackend.hpp
  56. 1 1
      source/backend/nnapi/execution/NNAPIConvolution.cpp
  57. 2 1
      source/backend/opencl/core/OpenCLBackend.cpp
  58. 2 1
      source/backend/opencl/core/OpenCLBackend.hpp
  59. 1 1
      source/backend/opencl/execution/buffer/BinaryBufExecution.cpp
  60. 2 2
      source/backend/opencl/execution/buffer/ConvBufExecution.cpp
  61. 1 1
      source/backend/opencl/execution/buffer/ConvBufWinograd.cpp
  62. 1 1
      source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp
  63. 1 1
      source/backend/opencl/execution/buffer/DeconvBufExecution.cpp
  64. 1 1
      source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp
  65. 1 1
      source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp
  66. 121 11
      source/backend/opencl/execution/buffer/LoopBufExecution.cpp
  67. 13 0
      source/backend/opencl/execution/buffer/LoopBufExecution.hpp
  68. 39 46
      source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp
  69. 2 1
      source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp
  70. 1 0
      source/backend/opencl/execution/cl/binary.cl
  71. 1 0
      source/backend/opencl/execution/cl/binary_buf.cl
  72. 64 1
      source/backend/opencl/execution/cl/loop.cl
  73. 56 95
      source/backend/opencl/execution/cl/loop_buf.cl
  74. 8 16
      source/backend/opencl/execution/cl/matmul_buf.cl
  75. 7 7
      source/backend/opencl/execution/cl/opencl_program.cc
  76. 249 105
      source/backend/opencl/execution/cl/softmax.cl
  77. 256 108
      source/backend/opencl/execution/cl/softmax_buf.cl
  78. 1 1
      source/backend/opencl/execution/image/ConvExecution.cpp
  79. 1 1
      source/backend/opencl/execution/image/ConvWinograd.cpp
  80. 1 1
      source/backend/opencl/execution/image/DeconvExecution.cpp
  81. 1 1
      source/backend/opencl/execution/image/DepthwiseConvExecution.cpp
  82. 1 1
      source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp
  83. 1 1
      source/backend/opencl/execution/image/EltwiseExecution.cpp
  84. 121 2
      source/backend/opencl/execution/image/LoopExecution.cpp
  85. 12 0
      source/backend/opencl/execution/image/LoopExecution.hpp
  86. 37 43
      source/backend/opencl/execution/image/SoftmaxExecution.cpp
  87. 2 1
      source/backend/opencl/execution/image/SoftmaxExecution.hpp
  88. 2 1
      source/backend/tensorrt/backend/TRTBackend.cpp
  89. 2 1
      source/backend/tensorrt/backend/TRTBackend.hpp
  90. 1 1
      source/backend/tensorrt/execution/TRTConvolution.cpp
  91. 1 1
      source/backend/tensorrt/execution/TRTDeconvolution.cpp
  92. 1 1
      source/backend/tensorrt/execution/TRTDepthwiseConvolution.cpp
  93. 1 1
      source/backend/tensorrt/execution/TRTDepthwiseDeconvolution.cpp
  94. 2 1
      source/backend/vulkan/buffer/backend/VulkanBackend.cpp
  95. 2 1
      source/backend/vulkan/buffer/backend/VulkanBackend.hpp
  96. 1 1
      source/backend/vulkan/buffer/execution/VulkanConvolution.cpp
  97. 1 1
      source/backend/vulkan/buffer/execution/VulkanDeconvolution.cpp
  98. 2 1
      source/backend/vulkan/image/backend/VulkanBackend.cpp
  99. 2 1
      source/backend/vulkan/image/backend/VulkanBackend.hpp
  100. 0 0
      source/backend/vulkan/image/execution/VulkanConvolution.cpp

+ 1 - 1
docs/contribute/backend.md

@@ -177,7 +177,7 @@ virtual void onResizeBegin();
 /**
  * @brief callback after resize ops.
  */
-virtual void onResizeEnd();
+virtual ErrorCode onResizeEnd();
 /**
  * @brief callback before executing ops.
  */
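With this change a backend's resize callback can report failure instead of silently returning void. A minimal sketch of the calling pattern under the new signature (illustrative only, not MNN code; the propagated code would typically come from the allocator, e.g. OUT_OF_MEMORY):

```cpp
// Illustrative calling pattern (not MNN code): resize failures can now be checked.
MNN::ErrorCode code = backend->onResizeEnd();
if (MNN::NO_ERROR != code) {
    MNN_ERROR("Resize failed with code %d\n", code);
    return code;   // propagate the failure instead of continuing with a broken plan
}
```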

+ 101 - 28
docs/inference/module.md

@@ -10,34 +10,52 @@
 - `VARP` is the input and output type of a `Module`, and is also the basic data structure of the [Expr API](expr.md)
 
 ## Workflow
-Create Executor (optional) -> create Module -> create input VARPs -> run inference with Module::forward -> use output VARPs -> destroy Module -> destroy Executor (optional)
-### Create an Executor
+Configure the Executor (optional) -> create a RuntimeManager (optional) -> create Module -> create input VARPs -> run inference with Module::forward -> use output VARPs -> destroy Module
+### (Optional) Configure the Executor
 `Executor` provides interfaces for configuring properties such as the inference backend and thread count, as well as performance statistics, per-operator execution callbacks and memory reclamation. A global Executor object is provided, so users can use it directly without creating or holding one themselves.
 ```cpp
-// Create a new Executor
-NNForwardType type = MNN_FORWARD_CPU;
-MNN::BackendConfig backend_config;    // default backend config 
-std::shared_ptr<MNN::Express::Executor> executor(
-    MNN::Express::Executor::newExecutor(type, backend_config, 4));
-MNN::Express::ExecutorScope scope(executor);
-// Use the default global Executor
+// Configure the default global Executor
 MNN::BackendConfig backend_config;    // default backend config 
-MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(type, backend_config, 4);
+// Use CPU with 4 threads
+MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(MNN_FORWARD_CPU, backend_config, 4);
 ``` 
+
+### (Optional) Create a RuntimeManager
+The Executor configuration affects the backend configuration of both Module inference and expression evaluation.
+
+*** For example, the following code triggers expression evaluation; if the Executor is set to OPENCL, the computation runs on the OpenCL backend
+```cpp
+MNN::Express::VARP X;
+MNN::Express::VARP Y = MNN::Express::_Sign(X);
+float* yPtr = Y->readMap<float>();
+```
+
+To use a particular backend configuration only for a given Module (for example, the Module runs on the GPU while expression evaluation stays on the CPU), create an additional RuntimeManager and pass it in when creating the Module
+```cpp
+MNN::ScheduleConfig sConfig;
+sConfig.type = MNN_FORWARD_OPENCL;
+
+std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr(MNN::Express::Executor::RuntimeManager::createRuntimeManager(sConfig), MNN::Express::Executor::RuntimeManager::destroy);
+rtmgr->setCache(".cachefile");
+```
+
 ### Create a Module
-A `Module` can be created from a model file, the input/output names and a config; it can also be `clone`d from an existing `Module` object
+A `Module` can be created from a model file, the input/output names and a config
 ```cpp
 // Load the model file and create a new Module
 const std::string model_file = "/tmp/mymodule.mnn"; // model file with path
+
+// Input names; may be empty, in which case MNN searches the model for inputs automatically. With multiple inputs the order is not guaranteed and should be checked via getInfo
 const std::vector<std::string> input_names{"input_1", "input_2", "input_3"};
+// Output names; may be empty, in which case MNN searches the model for outputs automatically. With multiple outputs the order is not guaranteed and should be checked via getInfo
 const std::vector<std::string> output_names{"output_1"};
+
 Module::Config mdconfig; // default module config
 std::unique_ptr<Module> module; // module 
-module.reset(Module::load(input_names, output_names, model_filename.c_str(), &mdconfig));
-// Create a new Module from an existing one; useful for multi-process concurrency
-std::unique_ptr<Module> module_shallow_copy;
-module_shallow_copy.reset(Module::clone(module.get()));
+// If rtMgr is nullptr, the Module uses the Executor's backend configuration
+module.reset(Module::load(input_names, output_names, model_filename.c_str(), rtMgr, &mdconfig));
 ```
+
 ### Get model information
 Call `getInfo` to obtain the `Module` information; see `tools/cpp/GetMNNInfo.cpp` and the [tool](../tools/test.html#getmnninfo) for reference
 ```cpp
@@ -57,41 +75,96 @@ struct Info {
 };
 const Info* getInfo() const;
 ```
+
 ### Run inference
 Call `onForward` to run inference.
 
-**Note: the `VARP`s returned by `onForward` become invalid once the `Module` is destructed**
-
 ```cpp
-std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs);
+std::vector<MNN::Express::VARP> onForward(const std::vector<MNN::Express::VARP>& inputs);
 ```
 
-## Model inference with Module
-Inference with Module supports control-flow operators, so Module is commonly used for speech models. Example code:
+Example code:
 
 ```cpp
 int dim = 224;
 std::vector<VARP> inputs(3);
-inputs[0] = _Input({1, dim}, NHWC, halide_type_of<int>());
-inputs[1] = _Input({1, dim}, NHWC, halide_type_of<int>());
-inputs[2] = _Input({1, dim}, NHWC, halide_type_of<int>());
+// Use NHWC for models converted from TensorFlow and NCHW for models converted from ONNX
+inputs[0] = MNN::Express::_Input({1, dim}, NHWC, halide_type_of<int>());
+inputs[1] = MNN::Express::_Input({1, dim}, NHWC, halide_type_of<int>());
+inputs[2] = MNN::Express::_Input({1, dim}, NHWC, halide_type_of<int>());
 
 // Set the input data
 std::vector<int*> input_pointer = {inputs[0]->writeMap<int>(),
                                    inputs[1]->writeMap<int>(),
                                    inputs[2]->writeMap<int>()};
-for (int i = 0; i < inputs[0]->getInfo->size; ++i) {
+for (int i = 0; i < dim; ++i) {
     input_pointer[0] = i + 1;
     input_pointer[1] = i + 2;
     input_pointer[2] = i + 3;
 }
 // Run inference
-std::vector<VARP> outputs  = module->onForward(inputs);
+std::vector<MNN::Express::VARP> outputs  = module->onForward(inputs);
 // Get the output
 auto output_ptr = outputs[0]->readMap<float>();
 ```
 
-Callback functions can be used for debugging, similar to [runSessionWithCallBack](session.html#id19). Example code:
+## Multi-instance inference
+
+The Module API supports creating multiple instances of the same model and dispatching them to different threads for inference. The steps are:
+
+- [Startup] The main thread creates the base Module: configure the Executor (optional) -> create a RuntimeManager (optional) -> create the Module
+- [Startup] Create worker threads and create an Executor in each one
+- [Startup] Each worker thread binds its own Executor and clones the Module
+- [Use] Each worker thread binds its own Executor and runs inference with its cloned Module: create input VARPs -> run inference with Module::forward -> use output VARPs
+- [Teardown] Each worker thread binds its own Executor and destroys its Module
+- [Teardown] Each worker thread destroys its Executor, then exits
+- [Teardown] The main thread destroys the base Module
+
+### Create the base Module
+The creation of the first instance does not need to change
+
+### Create a new Executor for each instance
+```cpp
+NNForwardType type = MNN_FORWARD_CPU;
+MNN::BackendConfig backend_config;    // default backend config 
+std::shared_ptr<MNN::Express::Executor> executor(
+    MNN::Express::Executor::newExecutor(type, backend_config, 1));
+```
+
+** If an algorithm pipeline runs several models, a single Executor per instance is sufficient.
+
+### Clone the base Module for each instance
+
+```cpp
+// Bind this instance's executor so that it does not conflict with other instances' memory
+MNN::Express::ExecutorScope scope(executor);
+std::unique_ptr<MNN::Express::Module> module_thread(MNN::Express::Module::clone(module.get()), MNN::Express::Module::destroy);
+```
+
+A cloned Module shares the immutable weights and constant data with the base Module, which greatly reduces the memory needed for each additional instance.
+
+
+### Run inference in each instance
+```cpp
+// Before each instance runs inference, bind its executor with ExecutorScope
+MNN::Express::ExecutorScope scope(executor);
+std::vector<VARP> inputs;
+/* Build the inputs ... */
+// Run inference
+std::vector<MNN::Express::VARP> outputs = module_thread->onForward(inputs);
+/* Use the outputs ... */
+``` 
+
+### Destruction
+```cpp
+// Before destroying each instance's Module, also bind that instance's executor with ExecutorScope
+MNN::Express::ExecutorScope scope(executor);
+module_thread.reset();
+```
+
+## Debugging
+
+The Module API also supports callback functions for debugging, similar to [runSessionWithCallBack](session.html#id19). Example code:
 ```cpp
 MNN::TensorCallBackWithInfo beforeCallBack = [&](const std::vector<MNN::Tensor*>& ntensors, const OperatorInfo* info) {
 
@@ -114,7 +187,7 @@ MNN::TensorCallBackWithInfo callBack = [&](const std::vector<MNN::Tensor*>& nten
     return true;
 };
 
-// set callback function
+// Set the callbacks on the executor that created this Module; in the single-instance case the global executor is fine:
 Express::Executor::getGlobalExecutor()->setCallBack(std::move(beforeCallBack), std::move(callBack));
 
 // forward would trigger callback
@@ -126,4 +199,4 @@ std::vector<VARP> outputs  = user_module->onForward(inputs);
 - `pictureRecognition_module.cpp` uses `Module` for image classification, with `ImageProcess` for preprocessing and `Expr` for postprocessing
 - `pictureRecognition_batch.cpp` uses `Module` for batched image classification, with `ImageProcess` for preprocessing and `Expr` for postprocessing
 - `multithread_imgrecog.cpp` uses `Module` to run image classification concurrently on multiple threads, with `ImageProcess` for preprocessing and `Expr` for postprocessing
-- `transformerDemo.cpp` uses `Module` to run Transformer model inference
+- `transformerDemo.cpp` uses `Module` to run Transformer model inference
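The multi-instance steps above are spread across several snippets; a minimal worker-thread sketch that combines them might look as follows. It only uses calls that appear in the doc (newExecutor, ExecutorScope, Module::clone/destroy, onForward); the include paths and the std::thread wrapper are assumptions and may need adjusting to the actual MNN headers.

```cpp
#include <memory>
#include <thread>
#include <vector>
#include <MNN/expr/Executor.hpp>        // include paths assumed; adjust to your MNN layout
#include <MNN/expr/ExecutorScope.hpp>
#include <MNN/expr/Module.hpp>

using namespace MNN::Express;

// One worker per cloned instance; baseModule is the Module created on the main thread.
static void worker(Module* baseModule) {
    // Per-thread Executor, as in "Create a new Executor for each instance"
    MNN::BackendConfig config;
    std::shared_ptr<Executor> executor(Executor::newExecutor(MNN_FORWARD_CPU, config, 1));
    ExecutorScope scope(executor);                         // bind this thread's executor
    // The clone shares weights and constants with the base Module
    std::shared_ptr<Module> cloned(Module::clone(baseModule), Module::destroy);
    std::vector<VARP> inputs;
    /* build inputs ... */
    auto outputs = cloned->onForward(inputs);
    /* use outputs ... */
    cloned.reset();                                        // destroy while the scope is still bound
}

// Usage sketch: spawn a few workers sharing the same base module.
// std::vector<std::thread> pool;
// for (int i = 0; i < 4; ++i) pool.emplace_back(worker, baseModule.get());
// for (auto& t : pool) t.join();
```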

+ 15 - 4
express/Expr.cpp

@@ -177,7 +177,9 @@ EXPRP Expr::create(Variable::Info&& info, const void* ptr, VARP::InputType type,
     }
     expr->mInside->mContentDirty = false;
     if (memtype == COPY) {
-        ::memcpy(expr->mInside->mOutputTensors[0]->buffer().host, originPtr, dstInfo.size * dstInfo.type.bytes());
+        size_t total_size = dstInfo.size;
+        total_size *= dstInfo.type.bytes();
+        ::memcpy(expr->mInside->mOutputTensors[0]->buffer().host, originPtr, total_size);
     } else {
         expr->mInside->mOutputTensors[0]->buffer().host = (uint8_t*)originPtr;
         if (memtype == REF) {
@@ -227,6 +229,9 @@ EXPRP Expr::create(const OpT* op, std::vector<VARP> inputs, int outputSize) {
             case DataType_DT_FLOAT:
                 ptr = (void*)op->main.AsBlob()->float32s.data();
                 break;
+            case DataType_DT_BFLOAT16:
+                ptr = (void*)op->main.AsBlob()->uint8s.data();
+                break;
             default:
                 break;
         }
@@ -1081,9 +1086,15 @@ void Variable::save(const std::vector<VARP>& vars, NetT* dest) {
                 blob->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info.order);
                 blob->dims       = info.dim;
                 if (info.type.code == halide_type_float) {
-                    blob->dataType = DataType_DT_FLOAT;
-                    blob->float32s.resize(info.size);
-                    ::memcpy(blob->float32s.data(), ptr, info.size * sizeof(float));
+                    if (info.type.bits == 16) {
+                        blob->dataType = DataType_DT_BFLOAT16;
+                        blob->uint8s.resize(info.size * 2);
+                        ::memcpy(blob->uint8s.data(), ptr, info.size * sizeof(int16_t));
+                    } else {
+                        blob->dataType = DataType_DT_FLOAT;
+                        blob->float32s.resize(info.size);
+                        ::memcpy(blob->float32s.data(), ptr, info.size * sizeof(float));
+                    }
                 } else if (info.type.code == halide_type_int && info.type.bits == 32) {
                     blob->dataType = DataType_DT_INT32;
                     blob->int32s.resize(info.size);

+ 1 - 0
express/Utils.cpp

@@ -81,6 +81,7 @@ halide_type_t Utils::revertDataType(DataType dataType) {
     CONVERT(DataType_DT_UINT8, halide_type_of<uint8_t>(), dataType);
     CONVERT(DataType_DT_INT8, halide_type_of<int8_t>(), dataType);
     CONVERT(DataType_DT_HALF, halide_type_of<float>(), dataType);
+    CONVERT(DataType_DT_BFLOAT16, halide_type_t(halide_type_float, 16), dataType);
     return halide_type_of<float>();
 }
 Express::Dimensionformat Utils::revertFormat(int format) {

+ 1 - 1
include/MNN/MNNDefine.h

@@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
 #define STR(x) STR_IMP(x)
 #define MNN_VERSION_MAJOR 2
 #define MNN_VERSION_MINOR 7
-#define MNN_VERSION_PATCH 0
+#define MNN_VERSION_PATCH 1
 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
 #endif /* MNNDefine_h */

+ 1 - 0
include/MNN/Tensor.hpp

@@ -227,6 +227,7 @@ public:
      * @return bytes needed to store data
      */
     int size() const;
+    size_t usize() const;
 
     /**
      * @brief calculate number of elements needed to store data taking reordering flag into account.

+ 18 - 7
package_scripts/win/build_lib_release.ps1

@@ -13,7 +13,8 @@
 Param(
     [Parameter(Mandatory=$true)][String]$path,
     [String]$backends,
-    [Switch]$x86
+    [Switch]$x86,
+    [Switch]$cibuild
 )
 
 $erroractionpreference = "stop"
@@ -25,14 +26,18 @@ mkdir -p $PACKAGE_LIB_PATH
 
 #clear and create package directory
 powershell ./schema/generate.ps1
-Remove-Item -Path $PACKAGE_PATH/include -Recurse -ErrorAction Ignore
-cp -r include $PACKAGE_PATH
-cp -r tools/cv/include/cv $PACKAGE_PATH/include
 pushd $PACKAGE_LIB_PATH
-mkdir -p Release\Dynamic\MT, Release\Dynamic\MD, Release\Static\MD, Release\Static\MT
+if ($cibuild) {
+    mkdir -p Release\Dynamic\MT
+} else {
+    Remove-Item -Path $PACKAGE_PATH/include -Recurse -ErrorAction Ignore
+    cp -r include $PACKAGE_PATH
+    cp -r tools/cv/include/cv $PACKAGE_PATH/include
+    mkdir -p Release\Dynamic\MT, Release\Dynamic\MD, Release\Static\MD, Release\Static\MT
+}
 popd
 
-$CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON  -DMNN_OPENCL=ON -DMNN_VULKAN=ON -DMNN_AVX512=ON"
+$CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON  -DMNN_OPENCL=ON -DMNN_VULKAN=ON -DMNN_AVX512=ON -DMNN_LOW_MEMORY=ON"
 if ($backends -ne $null) {
     Foreach ($backend in $backends.Split(",")) {
         if ($backend -eq "cuda") {
@@ -78,6 +83,12 @@ Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_M
 cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MT
 rm MNN.*
 
+# cibuild just build single type for build test
+if ($cibuild) {
+    popd
+    return
+}
+
 ##### Release/Dynamic/MD ####
 log "Release/Dynamic/MD"
 Remove-Item CMakeCache.txt -ErrorAction Ignore
@@ -97,4 +108,4 @@ Remove-Item CMakeCache.txt -ErrorAction Ignore
 Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=OFF -DMNN_BUILD_SHARED_LIBS=OFF .."
 cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MD
 
-popd
+popd

+ 1 - 1
pymnn/pip_package/build_deps.py

@@ -44,7 +44,7 @@ def build_deps():
         shutil.rmtree(cmake_build_dir)
     os.makedirs(cmake_build_dir)
     os.chdir(cmake_build_dir)
-    extra_opts = '-DMNN_LOW_MEMORY=OFF'
+    extra_opts = '-DMNN_LOW_MEMORY=ON'
     extra_opts += ' -DMNN_VULKAN=ON -DMNN_VULKAN_IMAGE=OFF'
     extra_opts += ' -DMNN_OPENCL=ON'
     if IS_WINDOWS:

+ 1 - 1
pymnn/src/MNN.cc

@@ -2157,7 +2157,7 @@ static PyObject* PyMNNCVMatrix_repr(PyObject *self) {
     ((PyMNNCVMatrix *)self)->matrix->get9(mat);
     char buffer [100];
     sprintf(buffer, "[[%f\t%f\t%f]\n [%f\t%f\t%f]\n [%f\t%f\t%f]]",
-            mat[0], mat[1], mat[2], mat[3], mat[4], mat[5], mat[5], mat[6], mat[7], mat[8]);
+            mat[0], mat[1], mat[2], mat[3], mat[4], mat[5], mat[6], mat[7], mat[8]);
     return toPyObj(buffer);
 }
 // type: 0 set; 1 pre; 2 post

+ 2 - 2
source/backend/arm82/Arm82Backend.cpp

@@ -80,11 +80,11 @@ Execution* Arm82Backend::onCreate(const std::vector<Tensor*>& inputs, const std:
     return exe;
 }
 
-static int _getAliginSize(const halide_buffer_t& buffer, MNN_DATA_FORMAT format) {
+static size_t _getAliginSize(const halide_buffer_t& buffer, MNN_DATA_FORMAT format) {
     // The default data type of input tensor for arm82 backend is FLOAT32.
     // However, Arm82Backend default data type is FLOAT16, so check whether data type is FLOAT32,
     // then divide size by 2
-    int size          = sizeof(int16_t);
+    size_t size          = sizeof(int16_t);
     const int dimensions = buffer.dimensions;
     for (int i = 0; i < dimensions; i++) {
         int currentDimSize = buffer.dim[i].extent;

+ 3 - 3
source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S

@@ -53,7 +53,7 @@ LoopH:
     mov w17, #0x0f
     dup v3.16b, w17
     and v2.16b, v0.16b, v3.16b
-    mov w17, #7
+    mov w17, #8
     dup v0.16b, w17
     sub v1.16b, v1.16b, v0.16b
     sub v2.16b, v2.16b, v0.16b
@@ -145,7 +145,7 @@ LoopH:
         mov w17, #0x0f
         dup v3.16b, w17
         and v2.16b, v0.16b, v3.16b
-        mov w17, #7
+        mov w17, #8
         dup v0.16b, w17
         sub v1.16b, v1.16b, v0.16b
         sub v2.16b, v2.16b, v0.16b
@@ -347,7 +347,7 @@ LoopHRemain:
     ld1 {v21.8h}, [x20], #16 // bias
     mov w17, #0x0f
     dup v22.16b, w17
-    mov w17, #7
+    mov w17, #8
     dup v23.16b, w17
     // ld1 {v3.8h}, [x2]
     ld1 {v3.8h}, [x2], #16

+ 4 - 4
source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S

@@ -69,7 +69,7 @@ LoopE8:
         ld1 {v12.8h, v13.8h}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v3.16b, w17
-        mov w17, #7
+        mov w17, #8
         dup v4.16b, w17
         ld1 {v14.8h, v15.8h}, [x16], #32 // bias
         subs x12, x9, #2
@@ -382,7 +382,7 @@ blt E1
         ld1 {v24.8h, v25.8h}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v30.16b, w17
-        mov w17, #7
+        mov w17, #8
         dup v31.16b, w17
         ld1 {v26.8h, v27.8h}, [x16], #32 // bias
         subs x12, x9, #2
@@ -565,7 +565,7 @@ blt E1
     // mov v4.d[1], v4.d[0]
     mov w17, #0x0f
     dup v30.8b, w17
-    mov w17, #7
+    mov w17, #8
     dup v31.8b, w17
     ld1 {v14.8h}, [x16], #16 // bias
     // mov v14.d[1], v14.d[0]
@@ -690,7 +690,7 @@ LoopE1:
         ld1 {v24.8h, v25.8h}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v30.16b, w17
-        mov w17, #7
+        mov w17, #8
         dup v31.16b, w17
         ld1 {v26.8h, v27.8h}, [x16], #32 // bias
         subs x12, x9, #2

+ 10 - 5
source/backend/coreml/backend/CoreMLBackend.cpp

@@ -127,8 +127,8 @@ namespace MNN {
         mCoreMLLayerPtrs.clear();
     }
 
-    void CoreMLBackend::onResizeEnd() {
-        buildModel();
+    ErrorCode CoreMLBackend::onResizeEnd() {
+        return buildModel();
     }
 
     std::string CoreMLBackend::getTensorName(const Tensor* t) {
@@ -226,7 +226,7 @@ namespace MNN {
         }
         *describe = des;
     }
-    void CoreMLBackend::buildModel() {
+    ErrorCode CoreMLBackend::buildModel() {
         mInputTensors.resize(mInputIdxMap.size());
         mCoreMLModel_->description = create<CoreML__Specification__ModelDescription>();
         core_ml__specification__model_description__init(mCoreMLModel_->description);
@@ -257,9 +257,14 @@ namespace MNN {
         }
 #endif
         if (mCoreMLModel_->neuralnetwork->n_layers <= 0) {
-            return;
+            return NO_EXECUTION;
+        }
+        bool success = mCoreMLExecutor->compileModel(mCoreMLModel_.get());
+        if (success) {
+            return NO_ERROR;
+        } else {
+            return NO_EXECUTION;
         }
-        mCoreMLExecutor->compileModel(mCoreMLModel_.get());
     }
     void CoreMLBackend::invokeModel() const {
         if (mCoreMLModel_->neuralnetwork->n_layers <= 0) {

+ 3 - 2
source/backend/coreml/backend/CoreMLBackend.hpp

@@ -12,6 +12,7 @@
 #include <stdio.h>
 #include <map>
 #include <memory>
+#include <MNN/ErrorCode.hpp>
 #include <core/Backend.hpp>
 #include <core/Execution.hpp>
 #include <core/TensorUtils.hpp>
@@ -57,7 +58,7 @@ namespace MNN {
         virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
 
         virtual void onResizeBegin() override;
-        virtual void onResizeEnd() override;
+        virtual ErrorCode onResizeEnd() override;
 
     public:
         // TODO: using memory pool instead static factory
@@ -95,7 +96,7 @@ namespace MNN {
         }
         std::string getTensorName(const Tensor* t);
         void addLayer(CoreML__Specification__NeuralNetworkLayer* layer);
-        void buildModel();
+        ErrorCode buildModel();
         void invokeModel() const;
         void setIO(CoreML__Specification__FeatureDescription** describe, const Tensor* t);
         void setLayerName(CoreML__Specification__NeuralNetworkLayer* layer, std::string&& name);

+ 1 - 1
source/backend/coreml/execution/CoreMLConvolution.cpp

@@ -29,7 +29,7 @@ void CoreMLConvolution::loadWeightBias(const std::vector<Tensor *> &inputs) {
     }
     auto conv2D = mOp->main_as_Convolution2D();
     if (nullptr != conv2D->quanParameter()) {
-        quanCommon = ConvolutionCommon::load(conv2D->quanParameter(), true);
+        quanCommon = ConvolutionCommon::load(conv2D, backend(), true);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", mOp->name()->c_str());
         }

+ 8 - 8
source/backend/cpu/CPUBackend.cpp

@@ -247,12 +247,12 @@ void CPUBackend::onResizeBegin() {
     mDynamicAllocator->reset();
 }
 
-void CPUBackend::onResizeEnd() {
+ErrorCode CPUBackend::onResizeEnd() {
     getCache()->release();
-    mDynamicAllocator->compute();
+    return mDynamicAllocator->compute();
 }
 
-Backend::MemObj* CPUBackend::allocBuffer(int size, Tensor* dest, StorageType storageType) {
+Backend::MemObj* CPUBackend::allocBuffer(size_t size, Tensor* dest, StorageType storageType) {
     auto originMem = TensorUtils::getDescribe(dest)->mem.get();
     if (nullptr != originMem) {
         if (static_cast<CPUMemObj*>(originMem)->getSize() >= size) {
@@ -263,7 +263,7 @@ Backend::MemObj* CPUBackend::allocBuffer(int size, Tensor* dest, StorageType sto
     }
     // MNN_PRINT("Acquire size = %d\n", size);
     if (size <= 0) {
-        MNN_PRINT("Acquire buffer size = %d\n", size);
+        MNN_PRINT("Acquire buffer size = %lu\n", size);
         MNN_ASSERT(false);
         return nullptr;
     }
@@ -337,19 +337,19 @@ static OpType _getRealOpType(OpType opType) {
             return opType;
     }
 }
-int CPUBackend::getTensorSize(const Tensor* tensor, bool multiBytes) const {
+size_t CPUBackend::getTensorSize(const Tensor* tensor, bool multiBytes) const {
     auto core = mCoreFunctions;
-    int dataSize = 1;
+    size_t dataSize = 1;
     auto des = TensorUtils::getDescribe(tensor);
     for (int i = 0; i < tensor->dimensions(); i++) {
-        int currentDimSize = tensor->length(i);
+        size_t currentDimSize = tensor->length(i);
         if (des->dimensionFormat == MNN_DATA_FORMAT_NC4HW4 && 1 == i) {
             currentDimSize = UP_DIV(currentDimSize, core->pack) * core->pack;
         }
         dataSize *= currentDimSize;
     }
     if (multiBytes) {
-        int bytes = tensor->getType().bytes();
+        size_t bytes = tensor->getType().bytes();
         if (TensorUtils::getDescribe(tensor)->quantAttr != nullptr) {
             if (TensorUtils::getDescribe(tensor)->type == DataType_DT_FLOAT) {
                 bytes = 4;

+ 3 - 3
source/backend/cpu/CPUBackend.hpp

@@ -91,13 +91,13 @@ public:
     virtual void onExecuteEnd() const override;
     
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
 
     const CoreFunctions* functions() const {
         return mCoreFunctions;
     }
     // Return element size for Tensor, conside pack
-    int getTensorSize(const Tensor* tensor, bool multiBytes = false) const;
+    size_t getTensorSize(const Tensor* tensor, bool multiBytes = false) const;
     const CoreInt8Functions* int8Functions() const {
         return mInt8CoreFunctions;
     }
@@ -139,7 +139,7 @@ public:
 
 
 protected:
-    MemObj* allocBuffer(int size, Tensor* dest,  StorageType storageType);
+    MemObj* allocBuffer(size_t size, Tensor* dest,  StorageType storageType);
     const CoreFunctions* mCoreFunctions;
     const CoreInt8Functions* mInt8CoreFunctions;
 private:

+ 24 - 0
source/backend/cpu/CPUCast.cpp

@@ -107,6 +107,27 @@ public:
         return NO_ERROR;
     }
 };
+class BF16ToFP32 : public Execution {
+public:
+    BF16ToFP32(Backend *b) : Execution(b) {
+        // nothing to do
+    }
+    virtual ~BF16ToFP32() = default;
+
+    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
+        auto input                = inputs[0];
+        auto output               = outputs[0];
+        auto srcData              = input->host<int16_t>();
+        auto dstData              = output->host<int16_t>();
+        const auto inputDataSize  = input->elementSize();
+        MNN_ASSERT(inputDataSize == output->elementSize());
+        for (int i = 0; i < inputDataSize; i++) {
+            dstData[i * 2] = 0;
+            dstData[i * 2 + 1] = srcData[i];
+        }
+        return NO_ERROR;
+    }
+};
 class CopyExecution : public Execution {
 public:
     CopyExecution(Backend *b) : Execution(b) {
@@ -168,6 +189,9 @@ Execution *CPUCastCreator::onCreate(const std::vector<Tensor *> &inputs, const s
     if (dstT == MNN::DataType_DT_FLOAT && halide_type_of<int8_t>() == inputDataType) {
         return new CastDataType<int8_t, float>(backend);
     }
+    if (dstT == MNN::DataType_DT_FLOAT && halide_type_t(halide_type_float, 16) == inputDataType) {
+        return new BF16ToFP32(backend);
+    }
     if (dstT == MNN::DataType_DT_INT8 && halide_type_of<float>() == inputDataType) {
         return new CastDataType<float, int8_t>(backend);
     }
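For context, bfloat16 is the top 16 bits of an IEEE-754 float32 (sign, 8 exponent bits, 7 mantissa bits), so widening it to float32 is just a placement of those 16 bits in the high half of the word. A standalone sketch of the same conversion (not MNN code):

```cpp
#include <cstdint>
#include <cstring>

// Standalone sketch (not MNN code): widen one bfloat16 value to float32.
float bf16_to_fp32(uint16_t bf16) {
    uint32_t bits = static_cast<uint32_t>(bf16) << 16;  // low 16 mantissa bits become zero
    float out;
    std::memcpy(&out, &bits, sizeof(out));              // reinterpret the bit pattern
    return out;
}
```

The new BF16ToFP32 execution does the same thing word by word with int16 stores; writing the payload to dstData[i * 2 + 1] places it in the high half on a little-endian host, which holds on the ARM/x86 targets MNN typically runs on.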

+ 1 - 1
source/backend/cpu/CPUConvolution.cpp

@@ -143,7 +143,7 @@ std::shared_ptr<CPUConvolution::ResourceInt8> CPUConvolution::makeResourceInt8(B
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
     resource->mOutputCount = outputCount;
-    if (!ConvolutionCommon::getConvInt8Parameters(convParam, quanCommon, weightSrc, weightSize, scalePtr, biasPtr)) {
+    if (!ConvolutionCommon::getConvInt8Parameters(convParam, quanCommon, backend, weightSrc, weightSize, scalePtr, biasPtr)) {
         return nullptr;
     }
     if (convParam->bias() && convParam->quanParameter()->alpha()) {

+ 1 - 1
source/backend/cpu/CPUConvolutionDepthwise.cpp

@@ -263,7 +263,7 @@ public:
         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
         std::unique_ptr<Tensor> externalWeightTensor, externalBiasTensor;
         if (nullptr != conv2d->quanParameter()) {
-            quanCommon = ConvolutionCommon::load(conv2d->quanParameter(), true);
+            quanCommon = ConvolutionCommon::load(conv2d, backend, true);
             // Back to float
             originWeight     = quanCommon->weightFloat.get();
             originWeightSize = quanCommon->weightFloat.size();

+ 3 - 3
source/backend/cpu/CPUDeconvolution.cpp

@@ -168,7 +168,7 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
     auto biasPtr = _bias.data();
     auto scalePtr = _scale.data();
     
-    if (USE_EXTERNAL_DATA(conv2d)) {
+    if (USE_EXTERNAL_DATA(conv2d) && conv2d->quanParameter() == nullptr) {
         auto bytes = conv2d->external()->Get(1);
         tempWeightSize = static_cast<int>(bytes / sizeof(float));
         externalWeightTensor.reset(Tensor::createDevice<float>({tempWeightSize}));
@@ -181,10 +181,10 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen
         tempWeight = externalWeightTensor->host<float>();
     } else {
         if (CPUBackend::getDataType(input) == DataType_DT_INT8 || input->getType().bytes() == 1) {
-            ConvolutionCommon::getConvInt8Parameters(conv2d, quanCommon, quanWeightInt8, tempWeightSize, scalePtr, biasPtr);
+            ConvolutionCommon::getConvInt8Parameters(conv2d, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr);
             ModeInt8 = true;
         } else {
-            ConvolutionCommon::getConvParameters(&quanCommon, conv2d, &tempWeight, &tempWeightSize);
+            ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2d, &tempWeight, &tempWeightSize);
         }
     }
 

+ 1 - 1
source/backend/cpu/CPUDeconvolutionDepthwise.cpp

@@ -27,7 +27,7 @@ CPUDeconvolutionDepthwise::CPUDeconvolutionDepthwise(const Tensor* input, const
     const float* tempWeight = nullptr;
     int tempWeightSize   = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &tempWeight, &tempWeightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, b, conv, &tempWeight, &tempWeightSize);
 
     // Reorder weight from whc -> pwhc4
     int kernelSize = depthQuad * core->pack * kw * kh;

+ 2 - 2
source/backend/cpu/CPURaster.cpp

@@ -805,7 +805,7 @@ public:
                     auto inputSize = input->elementSize();
                     auto output = mStack[cmd->indexes()->data()[0]];
                     auto bytes = input->getType().bytes();
-                    if (halide_type_float == input->getType().code) {
+                    if (halide_type_float == input->getType().code && bytes == 4) {
                         bytes = cpubackend->functions()->bytes;
                     }
                     auto proc = _selectUnitProc(bytes);
@@ -844,7 +844,7 @@ public:
                 auto inputSize = input->elementSize();
                 auto output = mStack[cmd->indexes()->data()[0]];
                 auto bytes = input->getType().bytes();
-                if (halide_type_float == input->getType().code) {
+                if (halide_type_float == input->getType().code && bytes == 4) {
                     bytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
                 }
                 auto proc = _selectUnitProc(bytes);

+ 4 - 2
source/backend/cpu/CPUTensorConvert.cpp

@@ -288,9 +288,9 @@ ErrorCode CPUTensorConverter::convert(const Tensor* input, const Tensor* output,
     if (nullptr == core) {
         core = MNNGetCoreFunctions();
     }
-    int bitLength = _getBytes(core, input);
+    size_t bitLength = _getBytes(core, input);
     if (ib.dimensions <= 1 || source == dest) {
-        int dataSize = 1;
+        size_t dataSize = 1;
         for (int i = 0; i < input->dimensions(); i++) {
             int currentDimSize = input->length(i);
             if (source == MNN_DATA_FORMAT_NC4HW4 && 1 == i) {
@@ -298,6 +298,8 @@ ErrorCode CPUTensorConverter::convert(const Tensor* input, const Tensor* output,
             }
             dataSize *= currentDimSize;
         }
+        // printf("convert # dataSize, bitLength = %d, %d\n", dataSize, bitLength);
+        // fflush(stdout);
         ::memcpy(ob.host, ib.host, dataSize * bitLength);
         return NO_ERROR;
     }

+ 1 - 1
source/backend/cpu/OneDNNConvInt8.cpp

@@ -68,7 +68,7 @@ Execution* OneDNNConvInt8::create(Backend* backend, const MNN::Convolution2D* co
     }
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
     if (convParam->quanParameter() != nullptr) {
-        quanCommon = ConvolutionCommon::load(convParam->quanParameter(), false);
+        quanCommon = ConvolutionCommon::load(convParam, backend(), false);
         weightSrc = quanCommon->weight.get();
     }
     auto user_weights = memory(user_weights_md, eng, (int8_t*)weightSrc);

+ 6 - 6
source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S

@@ -65,7 +65,7 @@ LoopE8:
         ld1 {v12.4s, v13.4s}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v3.8b, w17
-        mov w17, #7
+        mov w17, #8
         dup v4.8b, w17
         ld1 {v14.4s, v15.4s}, [x16], #32 // bias
         subs x12, x9, #2
@@ -299,7 +299,7 @@ LoopE8:
         ld1 {v4.4s}, [x14], #16 // alpha
         mov w17, #0x0f
         dup v30.8b, w17
-        mov w17, #7
+        mov w17, #8
         dup v31.8b, w17
         ld1 {v14.4s}, [x16], #16 // bias
         subs x12, x9, #4
@@ -544,7 +544,7 @@ blt E1
         ld1 {v24.4s, v25.4s}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v30.8b, w17
-        mov w17, #7
+        mov w17, #8
         dup v31.8b, w17
         ld1 {v26.4s, v27.4s}, [x16], #32 // bias
         subs x12, x9, #2
@@ -734,7 +734,7 @@ blt E1
     ld1 {v4.4s}, [x14], #16 // alpha
     mov w17, #0x0f
     dup v30.8b, w17
-    mov w17, #7
+    mov w17, #8
     dup v31.8b, w17
     ld1 {v14.4s}, [x16], #16 // bias
     subs x12, x9, #4
@@ -933,7 +933,7 @@ LoopE1:
         ld1 {v24.4s, v25.4s}, [x14], #32 // alpha
         mov w17, #0x0f
         dup v30.8b, w17
-        mov w17, #7
+        mov w17, #8
         dup v31.8b, w17
         ld1 {v26.4s, v27.4s}, [x16], #32 // bias
         subs x12, x9, #2
@@ -1039,7 +1039,7 @@ LoopE1:
     ld1 {v4.4s}, [x14], #16 // alpha
     mov w17, #0x0f
     dup v30.8b, w17
-    mov w17, #7
+    mov w17, #8
     dup v31.8b, w17
     ld1 {v14.4s}, [x16], #16 // bias
     subs x12, x9, #4

+ 3 - 3
source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S

@@ -57,7 +57,7 @@ LoopH:
     mov w17, #0x0f
     dup v3.8b, w17
     and v2.8b, v0.8b, v3.8b
-    mov w17, #7
+    mov w17, #8
     dup v3.8b, w17
     sub v0.8b, v1.8b, v3.8b
     sub v1.8b, v2.8b, v3.8b
@@ -153,7 +153,7 @@ LoopH:
         mov w17, #0x0f
         dup v3.8b, w17
         and v2.8b, v0.8b, v3.8b
-        mov w17, #7
+        mov w17, #8
         dup v3.8b, w17
         sub v0.8b, v1.8b, v3.8b
         sub v1.8b, v2.8b, v3.8b
@@ -370,7 +370,7 @@ LoopHRemain:
     ld1 {v21.4s}, [x20], #16 // bias
     mov w17, #0x0f
     dup v22.16b, w17
-    mov w17, #7
+    mov w17, #8
     dup v23.16b, w17
     // ld1 {v3.4s}, [x2]
     ld1 {v3.8h}, [x2], #16

+ 2 - 2
source/backend/cpu/bf16/BF16Backend.cpp

@@ -48,11 +48,11 @@ Execution* BF16Backend::onCreate(const std::vector<Tensor*>& inputs, const std::
     return nullptr;
 }
 
-static int _getAliginSize(const halide_buffer_t& buffer, MNN_DATA_FORMAT format) {
+static size_t _getAliginSize(const halide_buffer_t& buffer, MNN_DATA_FORMAT format) {
     // The default data type of input tensor for arm82 backend is FLOAT32.
     // However, BF16Backend default data type is FLOAT16, so check whether data type is FLOAT32,
     // then divide size by 2
-    int size          = sizeof(int16_t);
+    size_t size          = sizeof(int16_t);
     const int dimensions = buffer.dimensions;
     for (int i = 0; i < dimensions; i++) {
         int currentDimSize = buffer.dim[i].extent;

+ 1 - 1
source/backend/cpu/compute/ConvolutionFloatFactory.cpp

@@ -102,7 +102,7 @@ Execution* ConvolutionFloatFactory::create(const std::vector<Tensor*>& inputs, c
             // The weight is storage as float sparse, but the backend don't support sparse compute, expand it
             forceFloat = true;
         }
-        quanCommon = ConvolutionCommon::load(conv2d->quanParameter(), forceFloat, lowMemory);
+        quanCommon = ConvolutionCommon::load(conv2d, backend, forceFloat, lowMemory);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", op->name()->c_str());
             return nullptr;

+ 2 - 2
source/backend/cpu/compute/DeconvolutionWithStride.cpp

@@ -177,7 +177,7 @@ DeconvolutionWithStride::DeconvolutionWithStride(const Tensor* input, const Op*
     int tempWeightSize   = 0;
     int srcCount = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2D, &tempWeight, &tempWeightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, b, conv2D, &tempWeight, &tempWeightSize);
     srcCount = tempWeightSize / kx / ky / outputCount;
 
     int sy = common->strideY();
@@ -270,7 +270,7 @@ void DeconvolutionWithStride::_extract(const Op* convOp) {
     int tempWeightSize   = 0;
     int srcCount = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2D, &tempWeight, &tempWeightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend(), conv2D, &tempWeight, &tempWeightSize);
     srcCount = tempWeightSize / kx / ky / outputCount;
     
     std::shared_ptr<Tensor> weightWrap(

+ 1 - 1
source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp

@@ -109,7 +109,7 @@ static bool _initQuantizeResource(std::shared_ptr<ConvolutionCommon::Int8Common>
         for (int i=0; i<weightLength; ++i) {
             int s0 = srcPtr[2 * i + 0];
             int s1 = srcPtr[2 * i + 1];
-            int d = (s0 + 7) * 16 + (s1 + 7);
+            int d = (s0 + 8) * 16 + (s1 + 8);
             dstPtr[i] = d;
         }
         resource->mWeight = weightLow;
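This hunk, together with the mov w17, #7 -> #8 changes in the int4 assembly kernels and the _mm_set1_ps(7) -> (8) changes in the AVX/SSE gemm code, moves the int4 zero offset from 7 to 8: the stored unsigned nibble 0..15 now maps to the full signed range [-8, 7] instead of [-7, 8]. A standalone sketch of the encode/decode pair implied by the packing above (not MNN code):

```cpp
#include <cstdint>
#include <utility>

// Standalone sketch (not MNN code) of the int4 packing used in this commit:
// each signed weight in [-8, 7] is stored as an unsigned nibble (weight + 8),
// two weights per byte.
uint8_t pack_int4(int s0, int s1) {           // s0, s1 in [-8, 7]
    return static_cast<uint8_t>((s0 + 8) * 16 + (s1 + 8));
}

std::pair<int, int> unpack_int4(uint8_t d) {
    int s0 = (d >> 4) - 8;                    // the kernels do the same subtract-8
    int s1 = (d & 0x0f) - 8;
    return {s0, s1};
}
```

Dequantization then applies w = s * alpha + bias, as in _load_int4x4 above.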

+ 0 - 2
source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp

@@ -211,8 +211,6 @@ SparseConvolutionTiledExecutor::SparseConvolutionTiledExecutor(const Convolution
         weightNNZElement = optimalWeightNNZElement;
         weightBlockNumber = optimalWeightBlockNumber;
     }
-    MNN_ASSERT(weightNNZElement > 0);
-    MNN_ASSERT(weightBlockNumber > 0);
 
     mSparseIndexData.reset(new SparseIndexData(sparseBlockOC, weightNNZElement, weightBlockNumber, backend()));
 

+ 91 - 67
source/backend/cpu/x86_x64/avx/GemmFunction.hpp

@@ -832,7 +832,7 @@ static inline __m128 _load_int4x4(const uint8_t* src, __m128 alpha, __m128 bias)
     int iw2     = iw23 / 16;
     int iw3     = iw23 % 16;
     auto ws     = _mm_set_ps(iw3, iw2, iw1, iw0);
-    ws          = _mm_sub_ps(ws, _mm_set1_ps(7));
+    ws          = _mm_sub_ps(ws, _mm_set1_ps(8));
     ws          = _mm_add_ps(_mm_mul_ps(ws, alpha), bias);
     return ws;
 }
@@ -843,8 +843,8 @@ static inline __m256 _load_int4x8(const uint8_t* src, __m256 alpha, __m256 bias)
         int x = src[i];
         int a = x / 16;
         int b = x % 16;
-        w[i * 2] = a - 7;
-        w[i * 2 + 1] = b - 7;
+        w[i * 2] = a - 8;
+        w[i * 2 + 1] = b - 8;
     }
     auto w8 = LOAD8(w);
     return _mm256_add_ps(_mm256_mul_ps(w8, alpha), bias);
@@ -860,6 +860,7 @@ static void _AVX_MNNPackedMatMul_Main_int4(TYPE* C, const TYPE* A, const TYPE* f
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -869,10 +870,11 @@ static void _AVX_MNNPackedMatMul_Main_int4(TYPE* C, const TYPE* A, const TYPE* f
         auto s1  = LOAD8(A + 0 * 24 + 8);
         auto s2  = LOAD8(A + 0 * 24 + 16);
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z2  = _mm256_mul_ps(s2, w0);
@@ -891,10 +893,11 @@ static void _AVX_MNNPackedMatMul_Main_int4(TYPE* C, const TYPE* A, const TYPE* f
             s1  = LOAD8(A + sy * 24 + 8);
             s2  = LOAD8(A + sy * 24 + 16);
             ws  = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z2  = MNNAVXFMA(s2, w0, z2);
@@ -927,6 +930,7 @@ static void _AVX_MNNPackedMatMul_int4_20(TYPE* C, const TYPE* A, const uint8_t*
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -936,10 +940,11 @@ static void _AVX_MNNPackedMatMul_int4_20(TYPE* C, const TYPE* A, const uint8_t*
         auto s1  = LOAD8(A + 0 * aStride + 8);
         auto s2  = EXPAND_128(LOAD4(A + 0 * aStride + 16));
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z2  = _mm256_mul_ps(s2, w0);
@@ -957,10 +962,11 @@ static void _AVX_MNNPackedMatMul_int4_20(TYPE* C, const TYPE* A, const uint8_t*
             s1  = LOAD8(A + sy * aStride + 8);
             s2  = EXPAND_128(LOAD4(A + sy * aStride + 16));
             ws  = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z2  = MNNAVXFMA(s2, w0, z2);
@@ -991,6 +997,7 @@ static void _AVX_MNNPackedMatMul_int4_16(TYPE* C, const TYPE* A, const uint8_t*
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -999,10 +1006,11 @@ static void _AVX_MNNPackedMatMul_int4_16(TYPE* C, const TYPE* A, const uint8_t*
         auto s0  = LOAD8(A + 0 * aStride);
         auto s1  = LOAD8(A + 0 * aStride + 8);
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z3  = _mm256_mul_ps(s0, w1);
@@ -1015,10 +1023,11 @@ static void _AVX_MNNPackedMatMul_int4_16(TYPE* C, const TYPE* A, const uint8_t*
             s0  = LOAD8(A + sy * aStride);
             s1  = LOAD8(A + sy * aStride + 8);
             ws  = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z3  = MNNAVXFMA(s0, w1, z3);
@@ -1226,6 +1235,7 @@ static void _AVX_MNNPackedMatMul_int4_4(TYPE* C, const TYPE* A, const uint8_t* B
         STORE_8(dst2 + 16, sumAvx21);
         STORE_8(dst2 + 24, sumAvx31);
     }
+    float ws_tmp[4];
     for (int y = hR; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -1233,10 +1243,11 @@ static void _AVX_MNNPackedMatMul_int4_4(TYPE* C, const TYPE* A, const uint8_t* B
         auto bias   = _mm_loadu_ps(b + y * 4);
         auto s0     = LOAD4(A + 0 * aStride);
         auto ws     = _load_int4x4(weight, alpha, bias);
-        auto w0     = _mm_set1_ps(ws[0]);
-        auto w1     = _mm_set1_ps(ws[1]);
-        auto w2     = _mm_set1_ps(ws[2]);
-        auto w3     = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0     = _mm_set1_ps(ws_tmp[0]);
+        auto w1     = _mm_set1_ps(ws_tmp[1]);
+        auto w2     = _mm_set1_ps(ws_tmp[2]);
+        auto w3     = _mm_set1_ps(ws_tmp[3]);
         auto z0     = _mm_mul_ps(s0, w0);
         auto z3     = _mm_mul_ps(s0, w1);
         auto z6     = _mm_mul_ps(s0, w2);
@@ -1245,10 +1256,11 @@ static void _AVX_MNNPackedMatMul_int4_4(TYPE* C, const TYPE* A, const uint8_t* B
         for (int sy = 1; sy < l; ++sy) {
             s0 = LOAD4(A + sy * aStride);
             ws = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0 = MNNSSEFMA(s0, w0, z0);
             z3 = MNNSSEFMA(s0, w1, z3);
             z6 = MNNSSEFMA(s0, w2, z6);
@@ -1666,6 +1678,7 @@ static void _AVX_MNNPackedMatMul_Main_int8(TYPE* C, const TYPE* A, const TYPE* f
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -1675,10 +1688,11 @@ static void _AVX_MNNPackedMatMul_Main_int8(TYPE* C, const TYPE* A, const TYPE* f
         auto s1  = LOAD8(A + 0 * 24 + 8);
         auto s2  = LOAD8(A + 0 * 24 + 16);
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z2  = _mm256_mul_ps(s2, w0);
@@ -1697,10 +1711,11 @@ static void _AVX_MNNPackedMatMul_Main_int8(TYPE* C, const TYPE* A, const TYPE* f
             s1  = LOAD8(A + sy * 24 + 8);
             s2  = LOAD8(A + sy * 24 + 16);
             ws  = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z2  = MNNAVXFMA(s2, w0, z2);
@@ -1733,6 +1748,7 @@ static void _AVX_MNNPackedMatMul_int8_20(TYPE* C, const TYPE* A, const int8_t* B
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -1742,10 +1758,11 @@ static void _AVX_MNNPackedMatMul_int8_20(TYPE* C, const TYPE* A, const int8_t* B
         auto s1  = LOAD8(A + 0 * aStride + 8);
         auto s2  = EXPAND_128(LOAD4(A + 0 * aStride + 16));
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z2  = _mm256_mul_ps(s2, w0);
@@ -1763,10 +1780,11 @@ static void _AVX_MNNPackedMatMul_int8_20(TYPE* C, const TYPE* A, const int8_t* B
             s1  = LOAD8(A + sy * aStride + 8);
             s2  = EXPAND_128(LOAD4(A + sy * aStride + 16));
             ws  = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z2  = MNNAVXFMA(s2, w0, z2);
@@ -1797,6 +1815,7 @@ static void _AVX_MNNPackedMatMul_int8_16(TYPE* C, const TYPE* A, const int8_t* B
     auto bExtraStride = parameter[5] / sizeof(TYPE);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -1805,10 +1824,11 @@ static void _AVX_MNNPackedMatMul_int8_16(TYPE* C, const TYPE* A, const int8_t* B
         auto s0  = LOAD8(A + 0 * aStride);
         auto s1  = LOAD8(A + 0 * aStride + 8);
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm256_set1_ps(ws[0]);
-        auto w1  = _mm256_set1_ps(ws[1]);
-        auto w2  = _mm256_set1_ps(ws[2]);
-        auto w3  = _mm256_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm256_set1_ps(ws_tmp[0]);
+        auto w1  = _mm256_set1_ps(ws_tmp[1]);
+        auto w2  = _mm256_set1_ps(ws_tmp[2]);
+        auto w3  = _mm256_set1_ps(ws_tmp[3]);
         auto z0  = _mm256_mul_ps(s0, w0);
         auto z1  = _mm256_mul_ps(s1, w0);
         auto z3  = _mm256_mul_ps(s0, w1);
@@ -1821,10 +1841,11 @@ static void _AVX_MNNPackedMatMul_int8_16(TYPE* C, const TYPE* A, const int8_t* B
             s0  = LOAD8(A + sy * aStride);
             s1  = LOAD8(A + sy * aStride + 8);
             ws  = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0  = _mm256_set1_ps(ws[0]);
-            w1  = _mm256_set1_ps(ws[1]);
-            w2  = _mm256_set1_ps(ws[2]);
-            w3  = _mm256_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0  = _mm256_set1_ps(ws_tmp[0]);
+            w1  = _mm256_set1_ps(ws_tmp[1]);
+            w2  = _mm256_set1_ps(ws_tmp[2]);
+            w3  = _mm256_set1_ps(ws_tmp[3]);
             z0  = MNNAVXFMA(s0, w0, z0);
             z1  = MNNAVXFMA(s1, w0, z1);
             z3  = MNNAVXFMA(s0, w1, z3);
@@ -2032,6 +2053,7 @@ static void _AVX_MNNPackedMatMul_int8_4(TYPE* C, const TYPE* A, const int8_t* B,
         STORE_8(dst2 + 16, sumAvx21);
         STORE_8(dst2 + 24, sumAvx31);
     }
+    float ws_tmp[4];
     for (int y = hR; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + (y / 2) * cStride + 4 * (y % 2);
@@ -2039,10 +2061,11 @@ static void _AVX_MNNPackedMatMul_int8_4(TYPE* C, const TYPE* A, const int8_t* B,
         auto bias   = _mm_loadu_ps(b + y * 4);
         auto s0     = LOAD4(A + 0 * aStride);
         auto ws     = _load_int8x4(weight, alpha, bias);
-        auto w0     = _mm_set1_ps(ws[0]);
-        auto w1     = _mm_set1_ps(ws[1]);
-        auto w2     = _mm_set1_ps(ws[2]);
-        auto w3     = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0     = _mm_set1_ps(ws_tmp[0]);
+        auto w1     = _mm_set1_ps(ws_tmp[1]);
+        auto w2     = _mm_set1_ps(ws_tmp[2]);
+        auto w3     = _mm_set1_ps(ws_tmp[3]);
         auto z0     = _mm_mul_ps(s0, w0);
         auto z3     = _mm_mul_ps(s0, w1);
         auto z6     = _mm_mul_ps(s0, w2);
@@ -2051,10 +2074,11 @@ static void _AVX_MNNPackedMatMul_int8_4(TYPE* C, const TYPE* A, const int8_t* B,
         for (int sy = 1; sy < l; ++sy) {
             s0 = LOAD4(A + sy * aStride);
             ws = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0 = MNNSSEFMA(s0, w0, z0);
             z3 = MNNSSEFMA(s0, w1, z3);
             z6 = MNNSSEFMA(s0, w2, z6);

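The recurring change in this file replaces `ws[0]`-style lane indexing on a `__m128` (a GCC/Clang vector extension that MSVC does not accept) with a spill to a small float array via `_mm_storeu_ps`, followed by a broadcast of each lane. A minimal, self-contained sketch of the pattern, assuming AVX is enabled at compile time (illustrative only, not the MNN kernel):

    #include <immintrin.h>
    #include <cstdio>

    int main() {
        // a 4-lane vector of dequantized weights, like the result of _load_int8x4/_load_int4x4
        __m128 ws = _mm_set_ps(4.0f, 3.0f, 2.0f, 1.0f); // lanes [1, 2, 3, 4]

        // portable lane extraction: store to memory, then broadcast each lane
        float ws_tmp[4];
        _mm_storeu_ps(ws_tmp, ws);
        __m256 w0 = _mm256_set1_ps(ws_tmp[0]);
        __m256 w1 = _mm256_set1_ps(ws_tmp[1]);
        __m256 w2 = _mm256_set1_ps(ws_tmp[2]);
        __m256 w3 = _mm256_set1_ps(ws_tmp[3]);

        float check[8];
        _mm256_storeu_ps(check, w3);
        std::printf("lane 3 broadcast: %f\n", check[0]); // 4.000000
        (void)w0; (void)w1; (void)w2;
        return 0;
    }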
+ 67 - 49
source/backend/cpu/x86_x64/sse/GemmFunction.hpp

@@ -213,7 +213,7 @@ static inline __m128 _load_int4x4(const uint8_t* src, __m128 alpha, __m128 bias)
     int iw2     = iw23 / 16;
     int iw3     = iw23 % 16;
     auto ws     = _mm_set_ps(iw3, iw2, iw1, iw0);
-    ws          = _mm_sub_ps(ws, _mm_set1_ps(7));
+    ws          = _mm_sub_ps(ws, _mm_set1_ps(8));
     ws          = _mm_add_ps(_mm_mul_ps(ws, alpha), bias);
     return ws;
 }
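The `-7` to `-8` change above corrects the int4 zero point: the packed unsigned 4-bit codes 0..15 represent the signed range -8..7, so recentring must subtract 8; subtracting 7 shifted every dequantized weight by one quantization step. A scalar sketch of the same dequantization, with illustrative names (not MNN's helpers):

    #include <cstdint>
    #include <cstdio>

    // One byte holds two unsigned 4-bit codes; codes 0..15 stand for signed -8..7,
    // so the zero point is 8. Dequantize as (code - 8) * alpha + bias.
    static void dequant_int4_pair(uint8_t packed, float alpha, float bias, float out[2]) {
        int hi = packed >> 4;
        int lo = packed & 0x0F;
        out[0] = (hi - 8) * alpha + bias;
        out[1] = (lo - 8) * alpha + bias;
    }

    int main() {
        float v[2];
        dequant_int4_pair(0x0F, 1.0f, 0.0f, v); // codes 0 and 15
        std::printf("%f %f\n", v[0], v[1]);     // -8.000000 7.000000
        return 0;
    }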
@@ -226,6 +226,7 @@ static void _SSE_MNNPackedMatMul_12_int4(float* C, const float* A, const float*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + y * cStride;
@@ -235,10 +236,11 @@ static void _SSE_MNNPackedMatMul_12_int4(float* C, const float* A, const float*
         auto s1  = _mm_loadu_ps(A + 0 * 12 + 4);
         auto s2  = _mm_loadu_ps(A + 0 * 12 + 8);
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0  = _mm_mul_ps(s0, w0);
         auto z1  = _mm_mul_ps(s1, w0);
         auto z2  = _mm_mul_ps(s2, w0);
@@ -257,10 +259,11 @@ static void _SSE_MNNPackedMatMul_12_int4(float* C, const float* A, const float*
             s1  = _mm_loadu_ps(A + sy * 12 + 4);
             s2  = _mm_loadu_ps(A + sy * 12 + 8);
             ws = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0  = MNNSSEFMA(s0, w0, z0);
             z1  = MNNSSEFMA(s1, w0, z1);
             z2  = MNNSSEFMA(s2, w0, z2);
@@ -288,6 +291,7 @@ static void _SSE_MNNPackedMatMul_8_int4(float* C, const float* A, const uint8_t*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + y * cStride;
@@ -296,10 +300,11 @@ static void _SSE_MNNPackedMatMul_8_int4(float* C, const float* A, const uint8_t*
         auto s0  = _mm_loadu_ps(A + 0 * aStride);
         auto s1  = _mm_loadu_ps(A + 0 * aStride + 4);
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0  = _mm_mul_ps(s0, w0);
         auto z3  = _mm_mul_ps(s0, w1);
         auto z6  = _mm_mul_ps(s0, w2);
@@ -313,10 +318,11 @@ static void _SSE_MNNPackedMatMul_8_int4(float* C, const float* A, const uint8_t*
             s0  = _mm_loadu_ps(A + sy * aStride);
             s1  = _mm_loadu_ps(A + sy * aStride + 4);
             ws = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0  = MNNSSEFMA(s0, w0, z0);
             z3  = MNNSSEFMA(s0, w1, z3);
             z6  = MNNSSEFMA(s0, w2, z6);
@@ -339,6 +345,7 @@ static void _SSE_MNNPackedMatMul_4_int4(float* C, const float* A, const uint8_t*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride / 2;
         auto dst    = C + y * cStride;
@@ -346,10 +353,11 @@ static void _SSE_MNNPackedMatMul_4_int4(float* C, const float* A, const uint8_t*
         auto bias   = _mm_loadu_ps(b + y * 4);
         auto s0     = _mm_loadu_ps(A + 0 * aStride);
         auto ws  = _load_int4x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0     = _mm_mul_ps(s0, w0);
         auto z3     = _mm_mul_ps(s0, w1);
         auto z6     = _mm_mul_ps(s0, w2);
@@ -358,10 +366,11 @@ static void _SSE_MNNPackedMatMul_4_int4(float* C, const float* A, const uint8_t*
         for (int sy = 1; sy < l; ++sy) {
             s0 = _mm_loadu_ps(A + sy * aStride);
             ws = _load_int4x4(weight + sy * 2, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0 = MNNSSEFMA(s0, w0, z0);
             z3 = MNNSSEFMA(s0, w1, z3);
             z6 = MNNSSEFMA(s0, w2, z6);
@@ -435,6 +444,7 @@ static void _SSE_MNNPackedMatMul_12_int8(float* C, const float* A, const float*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + y * cStride;
@@ -444,10 +454,11 @@ static void _SSE_MNNPackedMatMul_12_int8(float* C, const float* A, const float*
         auto s1  = _mm_loadu_ps(A + 0 * 12 + 4);
         auto s2  = _mm_loadu_ps(A + 0 * 12 + 8);
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0  = _mm_mul_ps(s0, w0);
         auto z1  = _mm_mul_ps(s1, w0);
         auto z2  = _mm_mul_ps(s2, w0);
@@ -466,10 +477,11 @@ static void _SSE_MNNPackedMatMul_12_int8(float* C, const float* A, const float*
             s1  = _mm_loadu_ps(A + sy * 12 + 4);
             s2  = _mm_loadu_ps(A + sy * 12 + 8);
             ws = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0  = MNNSSEFMA(s0, w0, z0);
             z1  = MNNSSEFMA(s1, w0, z1);
             z2  = MNNSSEFMA(s2, w0, z2);
@@ -497,6 +509,7 @@ static void _SSE_MNNPackedMatMul_8_int8(float* C, const float* A, const int8_t*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + y * cStride;
@@ -505,10 +518,11 @@ static void _SSE_MNNPackedMatMul_8_int8(float* C, const float* A, const int8_t*
         auto s0  = _mm_loadu_ps(A + 0 * aStride);
         auto s1  = _mm_loadu_ps(A + 0 * aStride + 4);
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0  = _mm_mul_ps(s0, w0);
         auto z3  = _mm_mul_ps(s0, w1);
         auto z6  = _mm_mul_ps(s0, w2);
@@ -522,10 +536,11 @@ static void _SSE_MNNPackedMatMul_8_int8(float* C, const float* A, const int8_t*
             s0  = _mm_loadu_ps(A + sy * aStride);
             s1  = _mm_loadu_ps(A + sy * aStride + 4);
             ws = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0  = MNNSSEFMA(s0, w0, z0);
             z3  = MNNSSEFMA(s0, w1, z3);
             z6  = MNNSSEFMA(s0, w2, z6);
@@ -548,6 +563,7 @@ static void _SSE_MNNPackedMatMul_4_int8(float* C, const float* A, const int8_t*
     auto bExtraStride = parameter[5] / sizeof(float);
     auto bStride      = bExtraStride + l * 4;
     auto hC4          = UP_DIV(h, 4);
+    float ws_tmp[4];
     for (int y = 0; y < hC4; ++y) {
         auto weight = B + y * bStride;
         auto dst    = C + y * cStride;
@@ -555,10 +571,11 @@ static void _SSE_MNNPackedMatMul_4_int8(float* C, const float* A, const int8_t*
         auto bias   = _mm_loadu_ps(b + y * 4);
         auto s0     = _mm_loadu_ps(A + 0 * aStride);
         auto ws  = _load_int8x4(weight, alpha, bias);
-        auto w0  = _mm_set1_ps(ws[0]);
-        auto w1  = _mm_set1_ps(ws[1]);
-        auto w2  = _mm_set1_ps(ws[2]);
-        auto w3  = _mm_set1_ps(ws[3]);
+        _mm_storeu_ps(ws_tmp, ws);
+        auto w0  = _mm_set1_ps(ws_tmp[0]);
+        auto w1  = _mm_set1_ps(ws_tmp[1]);
+        auto w2  = _mm_set1_ps(ws_tmp[2]);
+        auto w3  = _mm_set1_ps(ws_tmp[3]);
         auto z0     = _mm_mul_ps(s0, w0);
         auto z3     = _mm_mul_ps(s0, w1);
         auto z6     = _mm_mul_ps(s0, w2);
@@ -567,10 +584,11 @@ static void _SSE_MNNPackedMatMul_4_int8(float* C, const float* A, const int8_t*
         for (int sy = 1; sy < l; ++sy) {
             s0 = _mm_loadu_ps(A + sy * aStride);
             ws = _load_int8x4(weight + sy * 4, alpha, bias);
-            w0 = _mm_set1_ps(ws[0]);
-            w1 = _mm_set1_ps(ws[1]);
-            w2 = _mm_set1_ps(ws[2]);
-            w3 = _mm_set1_ps(ws[3]);
+            _mm_storeu_ps(ws_tmp, ws);
+            w0 = _mm_set1_ps(ws_tmp[0]);
+            w1 = _mm_set1_ps(ws_tmp[1]);
+            w2 = _mm_set1_ps(ws_tmp[2]);
+            w3 = _mm_set1_ps(ws_tmp[3]);
             z0 = MNNSSEFMA(s0, w0, z0);
             z3 = MNNSSEFMA(s0, w1, z3);
             z6 = MNNSSEFMA(s0, w2, z6);

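The SSE variants above mirror the AVX ones: every `_SSE_MNNPackedMatMul_*` tile dequantizes four weights per reduction step, broadcasts each one, and accumulates into per-channel registers with fused multiply-add over the reduction length l. A plain-C++ reference of what one such 4-channel tile computes, under simplifying assumptions (weights already dequantized, packed A strip, no bExtraStride handling or NC4HW4 output repacking):

    #include <vector>
    #include <cstdio>

    // Reference for one tile: C[oc, x] = sum_sy A[sy, x] * W[sy, oc]
    // e = rows handled by the kernel variant (12, 8 or 4), l = reduction length,
    // W already dequantized (alpha * q + bias).
    static void packed_matmul_tile_ref(float* C, const float* A, const float* W,
                                       int e, int l) {
        for (int oc = 0; oc < 4; ++oc) {
            for (int x = 0; x < e; ++x) {
                float acc = 0.0f;
                for (int sy = 0; sy < l; ++sy) {
                    acc += A[sy * e + x] * W[sy * 4 + oc];
                }
                C[oc * e + x] = acc;
            }
        }
    }

    int main() {
        const int e = 4, l = 2;
        std::vector<float> A = {1, 2, 3, 4,   5, 6, 7, 8}; // l x e
        std::vector<float> W = {1, 0, 0, 0,   0, 1, 0, 0}; // l x 4
        std::vector<float> C(4 * e, 0.0f);
        packed_matmul_tile_ref(C.data(), A.data(), W.data(), e, l);
        std::printf("%f %f\n", C[0], C[4]); // 1.000000 5.000000
        return 0;
    }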
+ 1 - 1
source/backend/cuda/core/CUDABackend.cpp

@@ -301,7 +301,7 @@ Execution* CUDABackend::onCreate(const std::vector<Tensor*>& inputs, const std::
 void CUDABackend::onResizeBegin() {
 }
 
-void CUDABackend::onResizeEnd() {
+ErrorCode CUDABackend::onResizeEnd() {
+    return NO_ERROR;
 }
 
 void CUDABackend::onExecuteBegin() const {

+ 2 - 1
source/backend/cuda/core/CUDABackend.hpp

@@ -11,6 +11,7 @@
 
 #include <set>
 #include <vector>
+#include <MNN/ErrorCode.hpp>
 #include "MNN_generated.h"
 #include "backend/cuda/core/runtime/CUDARuntime.hpp"
 #include "core/Backend.hpp"
@@ -60,7 +61,7 @@ public:
     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                 const MNN::Op *op) override;
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
 
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;

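`onResizeEnd` now returns an `ErrorCode` instead of void, so failures during resize (for example, a failed workspace allocation) can be propagated to the caller rather than silently ignored. A self-contained sketch of the contract, using a stand-in enum instead of MNN/ErrorCode.hpp and a hypothetical allocateWorkspace() helper (both illustrative, not MNN APIs):

    #include <cstdio>

    // Stand-in for MNN's ErrorCode; only the success/failure shape matters here.
    enum ErrorCode { NO_ERROR = 0, OUT_OF_MEMORY = 1 };

    struct DummyBackend {
        bool allocateWorkspace() { return true; } // hypothetical helper
        ErrorCode onResizeEnd() {
            if (!allocateWorkspace()) {
                return OUT_OF_MEMORY; // resize failures are now reported, not swallowed
            }
            return NO_ERROR;
        }
    };

    int main() {
        DummyBackend backend;
        std::printf("onResizeEnd -> %d\n", (int)backend.onResizeEnd()); // 0 on success
        return 0;
    }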
+ 6 - 2
source/backend/cuda/core/runtime/CUDARuntime.hpp

@@ -107,6 +107,9 @@ public:
     size_t threads_num() {
         return mThreadPerBlock;
     }
+    const cudaDeviceProp& prop() const {
+        return mProp;
+    }
     int major_sm() const {
         return mProp.major;
     }
@@ -114,10 +117,11 @@ public:
         return mProp.major * 10 + mProp.minor;
     }
     size_t blocks_num(const size_t total_threads);
-    const cudaDeviceProp& prop() const {
-        return mProp;
+    const int smemPerBlock() {
+        return mProp.sharedMemPerBlock;
     }
 
+
     int selectDeviceMaxFreeMemory();
 
 private:

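The new `smemPerBlock()` accessor exposes `cudaDeviceProp::sharedMemPerBlock` so executions (e.g. TopKV2 below) can size their launches against the per-block shared-memory limit. A host-side sketch of the underlying query using the plain CUDA runtime API (not the MNN wrapper):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        cudaDeviceProp prop;
        if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
            std::printf("no CUDA device available\n");
            return 1;
        }
        // sharedMemPerBlock (a size_t) is what CUDARuntime::smemPerBlock() returns, narrowed to int.
        std::printf("shared memory per block: %zu bytes on sm_%d%d\n",
                    prop.sharedMemPerBlock, prop.major, prop.minor);
        return 0;
    }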
+ 1 - 1
source/backend/cuda/execution/ConvCutlassExecution.cu

@@ -26,7 +26,7 @@ ConvCutlassExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &filterDataPtr, &weightSize);
     auto oc = common->outputCount();
 
     int l = weightSize / oc;

+ 1 - 1
source/backend/cuda/execution/ConvDepthWiseExecution.cu

@@ -655,7 +655,7 @@ static std::shared_ptr<ConvDepthWiseExecution::Resource> _makeResource(const Op*
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &filterDataPtr, &weightSize);
     auto tempWeightStorage = pool->alloc(depthC * PACK_NUMBER * kernelY * kernelX * sizeof(float));
     auto tempWeight = (uint8_t*)tempWeightStorage.first + tempWeightStorage.second;
     cuda_check(cudaMemset(tempWeight, 0, depthC * PACK_NUMBER * kernelY * kernelX * sizeof(float)));

+ 1 - 1
source/backend/cuda/execution/ConvWinogradExecution.cu

@@ -67,7 +67,7 @@ ConvWinogradExecution::Resource::Resource(Backend* backend, const MNN::Op* op) {
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, conv, &filterDataPtr, &weightSize);
     mKernelInfo.kernelN = common->outputCount();
     mKernelInfo.kernelC = weightSize / mKernelInfo.kernelN / mKernelInfo.kernelX / mKernelInfo.kernelY;
 

+ 1 - 1
source/backend/cuda/execution/DeconvSingleInputExecution.cu

@@ -33,7 +33,7 @@ DeconvSingleInputExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &filterDataPtr, &weightSize);
     mKernelInfo.kernelN = common->outputCount();
     mKernelInfo.kernelC = weightSize / mKernelInfo.kernelN / mKernelInfo.kernelX / mKernelInfo.kernelY;
 

+ 36 - 22
source/backend/cuda/execution/TopKV2Execution.cu

@@ -1,7 +1,6 @@
 #include "TopKV2Execution.hpp"
 #include <memory>
 
-
 namespace MNN {
 namespace CUDA {
 
@@ -187,17 +186,22 @@ __global__ void GetResultAllRows(indexT * outputIndicesDevice, valueT * outputVa
 }
 
 
-int CalculateNumThreadPerBlock(const int K) {
-    int numThreadPerBlock;
-    if (K <= 48) {
-        numThreadPerBlock = 128;
-    } else if (K <= 96) {
-        numThreadPerBlock = 64;
-    } else {
-        numThreadPerBlock = 32;
-    }
+// The inequality "numThreadPerBlock * K * (sizeof(indexT) + sizeof(valueT)) <= smemPerBlock" must be guaranteed, which means numThreadPerBlock depends on K.
+template<typename indexT, typename valueT>
+int CalculateNumThreadPerBlock(const int K, const int smemPerBlock) {
+    int temp = smemPerBlock / (K * (sizeof(indexT) + sizeof(valueT)));
+    int numCalculate = std::pow(2, (std::floor(std::log2(temp))));
+    int numLimit = 1024;
+    return ALIMIN(numLimit, numCalculate);
+}
+
 
-    return numThreadPerBlock;
+// The inequality "numBlockPerRow * K * (sizeof(indexT) + sizeof(valueT)) <= smemPerBlock" must be guaranteed by restricting numElePerThread.
+template<typename indexT, typename valueT>
+int CalcualteNumElePerThread(const int K, const int numElePerRow, const int numThreadPerBlock, const int smemPerBlock) {
+    int numLimit = K;
+    int numCalculate = UP_DIV(numElePerRow, (smemPerBlock / (K * (sizeof(indexT) + sizeof(valueT)))) - 1);
+    return ALIMAX(numLimit, numCalculate);
 }
 
 
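Both helpers bound the launch so that every thread's K-element index/value scratch fits into shared memory: the thread count is the largest power of two satisfying numThreadPerBlock * K * (sizeof(indexT) + sizeof(valueT)) <= smemPerBlock, capped at 1024, and the per-thread element count is then raised so the per-row block count stays inside the same budget. A self-contained worked example of the thread-count side (plain C++ with explicit sizes in place of the templates):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Largest power-of-two thread count whose per-thread K-element scratch fits in shared memory.
    static int threadsPerBlock(int K, int smemPerBlock, int bytesPerElement) {
        int fit  = smemPerBlock / (K * bytesPerElement);          // threads that fit at all
        int pow2 = 1 << (int)std::floor(std::log2((double)fit));  // round down to a power of two
        return std::min(1024, pow2);
    }

    int main() {
        // e.g. K = 64, float values + int indices (8 bytes per slot), 48 KB of shared memory
        int threads = threadsPerBlock(64, 48 * 1024, (int)(sizeof(float) + sizeof(int)));
        std::printf("threads per block: %d\n", threads); // 64, since 64 * 64 * 8 = 32768 <= 49152
        return 0;
    }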
@@ -223,8 +227,17 @@ ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const s
 
     mParams.mNumElePerRow = mParams.mLengthRow;
     mParams.mNumK = outputs[0]->buffer().dim[outputs[0]->buffer().dimensions-1].extent;
-    mParams.mNumElePerThread = mParams.mNumK;
-    mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock(mParams.mNumK);
+    auto smemLimit = static_cast<CUDABackend*>(backend())->getCUDARuntime()->smemPerBlock();
+    if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) {
+        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, int>(mParams.mNumK, smemLimit);
+        mParams.mNumElePerThread = CalcualteNumElePerThread<int, int>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
+    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
+        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, half>(mParams.mNumK, smemLimit);
+        mParams.mNumElePerThread = CalcualteNumElePerThread<int, half>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
+    } else {
+        mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock<int, float>(mParams.mNumK, smemLimit);
+        mParams.mNumElePerThread = CalcualteNumElePerThread<int, float>(mParams.mNumK, mParams.mNumElePerRow, mParams.mNumThreadPerBlock, smemLimit);
+    }
     mParams.mNumElePerBlock = mParams.mNumElePerThread * mParams.mNumThreadPerBlock;
     mParams.mNumBlockPerRow = (mParams.mNumElePerRow - 1 + mParams.mNumElePerBlock) / mParams.mNumElePerBlock;
     mParams.mNumBlockFinal = mParams.mNumRow;
@@ -232,7 +245,7 @@ ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const s
     mParams.mNumBlockTotal = mParams.mNumBlockPerRow * mParams.mNumRow;
 
     // prepare temp buffer
-    auto pool = static_cast<CUDABackend*>(backend())->getStaticBufferPool();
+    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();
 
     if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) {
         auto bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
@@ -255,6 +268,7 @@ ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const s
         mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
         pool->free(bufferIndices);
         pool->free(bufferValues);
+
     }
 
     return NO_ERROR;
@@ -270,25 +284,25 @@ ErrorCode TopKV2Execution::onExecute(const std::vector<Tensor *> &inputs, const
     // configure threads
     dim3 grid1 = {mParams.mNumBlockPerRow, mParams.mNumRow};
     dim3 block1 = {mParams.mNumThreadPerBlock, 1};
-    int smemSize_1 = mParams.mNumThreadPerBlock * mParams.mNumK;
+    int smemSize1 = mParams.mNumThreadPerBlock * mParams.mNumK;
     dim3 grid2 = {mParams.mNumBlockFinal};
     dim3 block2 = {mParams.mNumThreadFinal};
-    int smemSize_2 = mParams.mNumBlockPerRow * mParams.mNumK;
+    int smemSize2 = mParams.mNumBlockPerRow * mParams.mNumK;
 
     if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) {
-        TopKAllRows<int, int><<<grid1, block1, smemSize_1 * (sizeof(int) + sizeof(int))>>>(static_cast<const int *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinInt, mParams.mDescendFlag);
+        TopKAllRows<int, int><<<grid1, block1, smemSize1 * (sizeof(int) + sizeof(int))>>>(static_cast<const int *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinInt, mParams.mDescendFlag);
         checkKernelErrors;
-        GetResultAllRows<int, int><<<grid2, block2, smemSize_2 * (sizeof(int) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<int *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
+        GetResultAllRows<int, int><<<grid2, block2, smemSize2 * (sizeof(int) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<int *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
         checkKernelErrors;
     } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
-        TopKAllRows<int, half><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const half *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinHalf, mParams.mDescendFlag);
+        TopKAllRows<int, half><<<grid1, block1, smemSize1 * (sizeof(half) + sizeof(int))>>>(static_cast<const half *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinHalf, mParams.mDescendFlag);
         checkKernelErrors;
-        GetResultAllRows<int, half><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<half *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
+        GetResultAllRows<int, half><<<grid2, block2, smemSize2 * (sizeof(half) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<half *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
         checkKernelErrors;
     } else {
-        TopKAllRows<int, float><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const float *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinFloat, mParams.mDescendFlag);
+        TopKAllRows<int, float><<<grid1, block1, smemSize1 * (sizeof(float) + sizeof(int))>>>(static_cast<const float *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinFloat, mParams.mDescendFlag);
         checkKernelErrors;
-        GetResultAllRows<int, float><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<float *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
+        GetResultAllRows<int, float><<<grid2, block2, smemSize2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<float *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
         checkKernelErrors;
     }
 

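The launch fix above sizes the fp16 paths' dynamic shared memory as smemSize1 * (sizeof(half) + sizeof(int)): the per-block budget is threads * K * (value bytes + index bytes), matching the half-precision value buffers the kernels index instead of the previous float-sized request. A small sketch of that arithmetic, with uint16_t standing in for half on the host:

    #include <cstdint>
    #include <cstdio>

    // Dynamic shared memory per block: one K-sized value slot and one K-sized index slot per thread.
    static size_t dynamicSmemBytes(int threadsPerBlock, int K, size_t valueBytes, size_t indexBytes) {
        return (size_t)threadsPerBlock * K * (valueBytes + indexBytes);
    }

    int main() {
        int threads = 64, K = 64;
        std::printf("fp32 path: %zu bytes\n", dynamicSmemBytes(threads, K, sizeof(float), sizeof(int)));    // 32768
        std::printf("fp16 path: %zu bytes\n", dynamicSmemBytes(threads, K, sizeof(uint16_t), sizeof(int))); // 24576
        return 0;
    }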
+ 376 - 229
source/backend/cuda/execution/Transpose.cu

@@ -329,200 +329,337 @@ void Transpose(uint8_t* output, const uint8_t* input, const TransposeParam* cpuP
     }
 }
 
+// for the following transpose kernels:
+// maxCount is the number of threads, i.e. the number of elements in the output format
+// inChannelPack is the channel pack size of the input format
+// divOutChannelPack is the DivModFast helper for the channel pack of the output format
+
+// copy kernel
 template<typename T0, typename T1>
-__global__ void NCHW_2_NHWC8(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int channel,
-    const int area,
-    const int channel_pack,
-    DivModFast d_ocp,
-    DivModFast d_area
+__global__ void NCHW_2_NCHW(const T0* input,
+                            T1* output,
+                            const int maxCount
 ) {
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int area_idx, temp, chnlp_idx, batch_idx;
-        d_ocp.divmod(index, temp, chnlp_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        output[index] = (T1)input[index];
+    }
+}
 
-        if(chnlp_idx >= channel) {
-            output[index] = (T1)0.0f;
-            continue;
-        }
-        int src_offset = (batch_idx * channel + chnlp_idx) * area + area_idx;
+// NHWC NCHW
+template<typename T0, typename T1>
+__global__ void NHWC_2_NCHW(const T0* input,
+                            T1* output,
+                            const int maxCount,
+                            const int channel, // redundant parameter
+                            const int area,
+                            const int inChannelPack,
+                            DivModFast divOutChannelPack,
+                            DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// NHWC8_2_NCHW
+template<typename T0, typename T1>
+__global__ void NHWC8_2_NCHW(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel, // redundant parameter
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
         output[index] = (T1)input[src_offset];
     }
 }
 
+// C4NHW4_2_NCHW
+template<typename T0, typename T1>
+__global__ void C4NHW4_2_NCHW(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack, // redundant parameter
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    const int batch = (maxCount / channel) / area;
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divArea.divmod(index, temp, area_idx);
+        divOutChannelPack.divmod(temp, batch_idx, chnl_idx);
+
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// NCHW NHWC
 template<typename T0, typename T1>
 __global__ void NCHW_2_NHWC(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int channel,
-    const int area,
-    const int channel_pack,
-    DivModFast d_oc,
-    DivModFast d_area
+                            T1* output,
+                            const int maxCount,
+                            const int channel, // redundant parameter
+                            const int area,
+                            const int inChannelPack,
+                            DivModFast divOutChannelPack,
+                            DivModFast divArea
 ) {
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
         int area_idx, temp, chnl_idx, batch_idx;
-        d_oc.divmod(index, temp, chnl_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
-        
-        int src_offset = (batch_idx * channel + chnl_idx) * area + area_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// NHWC8 NHWC
+template<typename T0, typename T1>
+__global__ void NHWC8_2_NHWC(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel, // redundant parameter
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+        output[index] = (T1)input[src_offset];
+    }
+}
+
+// C4NHW4 NHWC
+template<typename T0, typename T1>
+__global__ void C4NHW4_2_NHWC(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack, // redundant parameter
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
+) {
+    const int batch = (maxCount / channel) / area;
+    for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
         output[index] = (T1)input[src_offset];
     }
 }
 
+// NHWC NHWC8
 template<typename T0, typename T1>
 __global__ void NHWC_2_NHWC8(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int channel,
-    const int area,
-    const int channel_pack,
-    DivModFast d_ocp,
-    DivModFast d_area
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
 ) {
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int area_idx, temp, chnlp_idx, batch_idx;
-        d_ocp.divmod(index, temp, chnlp_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
 
-        if(chnlp_idx >= channel) {
+        if(chnl_idx >= channel) {
             output[index] = (T1)0.0f;
             continue;
         }
-        int src_offset = (batch_idx * area + area_idx) * channel + chnlp_idx;
+
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
         output[index] = (T1)input[src_offset];
     }
 }
 
+// NCHW NHWC8
 template<typename T0, typename T1>
-__global__ void NHWC8_2_NCHW(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int channel,
-    const int area,
-    const int channel_pack,
-    DivModFast d_oc,
-    DivModFast d_area
+__global__ void NCHW_2_NHWC8(const T0* input,
+                             T1* output,
+                             const int maxCount,
+                             const int channel,
+                             const int area,
+                             const int inChannelPack,
+                             DivModFast divOutChannelPack,
+                             DivModFast divArea
 ) {
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
 
-        int area_idx, temp, channel_idx, batch_idx;
-        d_area.divmod(index, temp, area_idx);
-        d_oc.divmod(temp, batch_idx, channel_idx);
+        if(chnl_idx >= channel) {
+            output[index] = (T1)0.0f;
+            continue;
+        }
 
-        int src_offset = (batch_idx * area + area_idx) * channel_pack + channel_idx;
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
         output[index] = (T1)input[src_offset];
     }
 }
 
+// C4NHW4 NHWC8
 template<typename T0, typename T1>
-__global__ void NHWC8_2_NHWC(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int channel,
-    const int area,
-    const int channel_pack,
-    DivModFast d_oc,
-    DivModFast d_area
+__global__ void C4NHW4_2_NHWC8(const T0* input,
+                              T1* output,
+                              const int maxCount,
+                              const int channel,
+                              const int area,
+                             const int inChannelPack, // redundant parameter
+                              DivModFast divOutChannelPack,
+                              DivModFast divArea
 ) {
+    const int batch = (maxCount / (UP_DIV(channel, 8) * 8)) / area;
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
 
-        int area_idx, temp, channel_idx, batch_idx;
-        d_oc.divmod(index, temp, channel_idx);
-        d_area.divmod(temp, batch_idx, area_idx);
+        if(chnl_idx >= channel) {
+            output[index] = (T1)0.0f;
+            continue;
+        }
 
-        int src_offset = (batch_idx * area + area_idx) * channel_pack + channel_idx;
+        int c4_idx = chnl_idx >> 2;
+        int cL_idx = chnl_idx & 3;
+        int src_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
         output[index] = (T1)input[src_offset];
     }
 }
 
+// NHWC_2_C4NHW4
 template<typename T0, typename T1>
-__global__ void NCHW_2_NCHW(const T0* input,
-    T1* output,
-    const int maxCount
+__global__ void NHWC_2_C4NHW4(const T0* input,
+                               T1* output,
+                               const int maxCount,
+                               const int channel,
+                               const int area,
+                               const int inChannelPack,
+                               DivModFast divOutChannelPack,
+                               DivModFast divArea
 ) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        output[index] = (T1)input[index];
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3; // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
+
+        if (chnl_idx >= channel) {
+            output[dst_offset] = (T1)0.0f;
+            continue;
+        }
+
+        output[dst_offset] = (T1)input[src_offset];
     }
 }
 
+// NCHW C4NHW4
 template<typename T0, typename T1>
-__global__ void C4NHW4_2_NHWC8(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int batch,
-    const int area,
-    const int channel,
-    const int channel_pack
+__global__ void NCHW_2_C4NHW4(const T0* input,
+                               T1* output,
+                               const int maxCount,
+                               const int channel,
+                               const int area,
+                               const int inChannelPack,
+                               DivModFast divOutChannelPack,
+                               DivModFast divArea
 ) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int c_idx = index % channel_pack;
-        int temp = index / channel_pack;
-        int hw_idx = temp % area;
-        int batch_idx = temp / area;
-
-	if(c_idx >= channel) {
-	    output[index] = (T1)0.0f;
-	    continue;
-	}
-        int c4_idx = c_idx >> 2;
-        int cL_idx = c_idx & 3;
-        output[index] = (T1)input[((c4_idx * batch + batch_idx) * area + hw_idx) * 4 + cL_idx];
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3; // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * inChannelPack + chnl_idx) * area + area_idx;
+
+        if (chnl_idx >= channel) {
+            output[dst_offset] = (T1)0.0f;
+            continue;
+        }
+
+        output[dst_offset] = (T1)input[src_offset];
     }
 }
 
+// NHWC8 C4NHW4
 template<typename T0, typename T1>
 __global__ void NHWC8_2_C4NHW4(const T0* input,
-    T1* output,
-    const int maxCount,
-    const int batch,
-    const int channel,
-    const int area,
-    const int channel_pack
+                               T1* output,
+                               const int maxCount,
+                               const int channel,
+                               const int area,
+                               const int inChannelPack,
+                               DivModFast divOutChannelPack,
+                               DivModFast divArea
 ) {
+    const int batch = (maxCount / (UP_DIV(channel, 4) * 4)) / area;
     for(size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
-        int c_idx = index % channel_pack;
-        int temp = index / channel_pack;
-        int hw_idx = temp % area;
-        int batch_idx = temp / area;
+        // arrange threads according to the NHWC4 format
+        int area_idx, temp, chnl_idx, batch_idx;
+        divOutChannelPack.divmod(index, temp, chnl_idx);
+        divArea.divmod(temp, batch_idx, area_idx);
+
+        int c4_idx = chnl_idx >> 2; // chnl_idx / 4
+        int cL_idx = chnl_idx & 3; // chnl_idx % 4
+        int dst_offset = ((c4_idx * batch + batch_idx) * area + area_idx) * 4 + cL_idx;
+        int src_offset = (batch_idx * area + area_idx) * inChannelPack + chnl_idx;
 
-        int channel_8 = ((channel + 7) / 8) * 8;
-        int c4_idx = c_idx >> 2;
-        int cL_idx = c_idx & 3;
-        output[((c4_idx * batch + batch_idx) * area + hw_idx) * 4 + cL_idx] = 
-            (T1)input[(batch_idx * area + hw_idx) * channel_8 + c_idx];
+        output[dst_offset] = (T1)input[src_offset];
     }
 }
 
 template<class T0, class T1>
 static void insideFormatConvert(T0* input, T1* output, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
-    const int area, const int batch, const int channel) {
+    const int area, const int batch, const int channel, const bool srcDevice, const bool dstDevice) {
     DivModFast d_oc(channel);
-    DivModFast d_ocp(UP_DIV(channel, 8) * 8);
+    DivModFast d_oc4(UP_DIV(channel, 4) * 4);
+    DivModFast d_oc8(UP_DIV(channel, 8) * 8);
     DivModFast d_area(area);
 
-    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NCHW_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                            d_ocp, d_area);
-        checkKernelErrors;
-        return;
-    }
-    if(srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NHWC_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                            d_ocp, d_area);
-        checkKernelErrors;
-        return;
-    }
-    if((srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NCHW) || \
+    // NCHW NCHW
+    // NHWC NHWC
+    if ((srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NCHW) || \
         (srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NHWC)) {
         const int maxCount = batch * area * channel;
         const int block_num = runtime->blocks_num(maxCount);
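All of the rewritten kernels above follow one indexing scheme: each thread owns a single element of the destination layout, splits its linear index with the divOutChannelPack / divArea helpers into (batch, channel, spatial) coordinates, recomputes the source offset for the given input layout, and zero-fills channel lanes past `channel` when the destination is padded. A host-side reference for one case, NCHW source to 8-packed NHWC destination, mirroring NCHW_2_NHWC8 above (illustrative only):

    #include <vector>
    #include <cstdio>

    // Host reference for NCHW -> NHWC8: the destination index is decomposed by
    // (channel pack, area, batch); lanes beyond `channel` are zero padding.
    static void nchw_to_nhwc8_ref(const float* src, float* dst,
                                  int batch, int channel, int area) {
        const int cPack = ((channel + 7) / 8) * 8;
        const int maxCount = batch * area * cPack;
        for (int index = 0; index < maxCount; ++index) {
            int chnl_idx  = index % cPack;
            int temp      = index / cPack;
            int area_idx  = temp % area;
            int batch_idx = temp / area;
            if (chnl_idx >= channel) {
                dst[index] = 0.0f;
                continue;
            }
            dst[index] = src[(batch_idx * channel + chnl_idx) * area + area_idx];
        }
    }

    int main() {
        int batch = 1, channel = 3, area = 2;
        std::vector<float> src = {1, 2,  3, 4,  5, 6};  // NCHW: c0={1,2} c1={3,4} c2={5,6}
        std::vector<float> dst(batch * area * 8, -1.0f);
        nchw_to_nhwc8_ref(src.data(), dst.data(), batch, channel, area);
        std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 1 3 5 0
        return 0;
    }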
@@ -531,168 +668,178 @@ static void insideFormatConvert(T0* input, T1* output, MNN_DATA_FORMAT srcDataFo
         checkKernelErrors;
         return;
     }
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
-        const int maxCount = batch * area * channel;
-        const int block_num = runtime->blocks_num(maxCount);
-        const int block_size = runtime->threads_num();
-        NHWC8_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                            d_oc, d_area);
-        checkKernelErrors;
+
+    // NC4HW4 NC4HW4
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if(!srcDevice && dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc8, d_area);
+            checkKernelErrors;
+        } else if (srcDevice && !dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NCHW_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount);
+            checkKernelErrors;
+        }
         return;
     }
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+
+    // NHWC NCHW
+    if (srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
         const int maxCount = batch * area * channel;
         const int block_num = runtime->blocks_num(maxCount);
         const int block_size = runtime->threads_num();
-        NHWC8_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                            d_oc, d_area);
+        NHWC_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc, d_area);
         checkKernelErrors;
         return;
     }
-    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+
+    // NC4HW4 NCHW
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NCHW) {
+        if (!srcDevice) {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_NCHW<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc, d_area);
+            checkKernelErrors;
+        }
+        return;
+    }
+
+    // NCHW NHWC
+    if (srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
         const int maxCount = batch * area * channel;
         const int block_num = runtime->blocks_num(maxCount);
         const int block_size = runtime->threads_num();
-        NCHW_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8,
-                            d_oc, d_area);
+        NCHW_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc, d_area);
         checkKernelErrors;
         return;
     }
-    MNN_PRINT("insideFormatConvert form %d to %d, not support\n", (int)srcDataFormat, (int)dstDataFormat);
-    
-}
-
-void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
-    const int area, const int batch, const int channel, const Tensor* srcTensor, int precision, bool srcDevice, bool dstDevice) {
-
-    bool isFp16 = (precision == 2);
-    bool isBf16 = (precision == 3);
-    if(batch == 0 || area == 0 || channel == 0) {
-        MNN_PRINT("Error: formatConvert size batch:%d - plane:%d - channel:%d, format:%d->%d, device:%d->%d\n", batch, area, channel, srcDataFormat, dstDataFormat, srcDevice, dstDevice);
-        return;
-    }
 
-    auto des = TensorUtils::getDescribe(srcTensor);
-    if ((des->quantAttr.get() != nullptr && des->type == DataType_DT_INT8) || srcTensor->getType().bits == 8) {
-        if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-            if(!srcDevice && dstDevice) {
-                const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
-                const int block_num = runtime->blocks_num(maxCount);
-                const int block_size = runtime->threads_num();
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((int8_t *)input, (int8_t *)output, 
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-                return;
-            }
-    
-            if(srcDevice && !dstDevice) {
-                const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
-                const int block_num = runtime->blocks_num(maxCount);
-                const int block_size = runtime->threads_num();
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((int8_t *)input, (int8_t *)output, 
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-                return;
-            }
+    // NC4HW4 NHWC
+    if (srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NHWC) {
+        if (!srcDevice) {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            C4NHW4_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 4) * 4, d_oc, d_area);
+            checkKernelErrors;
+        } else {
+            const int maxCount = batch * area * channel;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NHWC8_2_NHWC<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, UP_DIV(channel, 8) * 8, d_oc, d_area);
+            checkKernelErrors;
         }
-    
-        insideFormatConvert<int8_t, int8_t>((int8_t *)input, (int8_t *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
         return;
     }
 
-    isFp16 = isFp16 & (halide_type_float == srcTensor->getType().code);
-    isBf16 = isBf16 & (halide_type_float == srcTensor->getType().code);
-    if(srcDataFormat == MNN_DATA_FORMAT_NC4HW4 && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
-        if(!srcDevice && dstDevice) {
+    // NCHW NC4HW4
+    if(srcDataFormat == MNN_DATA_FORMAT_NCHW && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if (!dstDevice) {
+            const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
+            const int block_num = runtime->blocks_num(maxCount);
+            const int block_size = runtime->threads_num();
+            NCHW_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
             const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
             const int block_num = runtime->blocks_num(maxCount);
             const int block_size = runtime->threads_num();
-            if(isFp16) {
-		        C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (half *)output, 
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-            } else if(isBf16) {
-                #ifdef ENABLE_CUDA_BF16
-		        C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (__nv_bfloat16 *)output, 
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-                #endif
-            } else {
-                C4NHW4_2_NHWC8<<<block_num, block_size>>>((float *)input, (float *)output, 
-                    maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
-                checkKernelErrors;
-            }
-            return;
+            NCHW_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc8, d_area);
+            checkKernelErrors;
         }
+        return;
+    }
 
-        if(srcDevice && !dstDevice) {
+    // NHWC NC4HW4
+    if(srcDataFormat == MNN_DATA_FORMAT_NHWC && dstDataFormat == MNN_DATA_FORMAT_NC4HW4) {
+        if (!dstDevice) {
             const int maxCount = batch * area * UP_DIV(channel, 4) * 4;
             const int block_num = runtime->blocks_num(maxCount);
             const int block_size = runtime->threads_num();
-            if(isFp16) {
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((half *)input, (float *)output, 
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-            } else if(isBf16) {
-                #ifdef ENABLE_CUDA_BF16
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((__nv_bfloat16 *)input, (float *)output, 
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-                #endif
-            } else {
-                NHWC8_2_C4NHW4<<<block_num, block_size>>>((float *)input, (float *)output, 
-                    maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
-                checkKernelErrors;
-            }
-            return;
-        }
-
-        if(srcDevice && dstDevice) {
+            NHWC_2_C4NHW4<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc4, d_area);
+            checkKernelErrors;
+        } else {
             const int maxCount = batch * area * UP_DIV(channel, 8) * 8;
             const int block_num = runtime->blocks_num(maxCount);
             const int block_size = runtime->threads_num();
-            if(isFp16 || isBf16) {
-                NCHW_2_NCHW<half, half><<<block_num, block_size>>>((half *)input, (half *)output, maxCount);
-                checkKernelErrors;
-            } else {
-                NCHW_2_NCHW<float, float><<<block_num, block_size>>>((float *)input, (float *)output, maxCount);
-                checkKernelErrors; 
-            }
-            return;
+            NHWC_2_NHWC8<T0, T1><<<block_num, block_size>>>(input, output, maxCount, channel, area, channel, d_oc8, d_area);
+            checkKernelErrors;
         }
+        return;
+    }
+
+    MNN_ERROR("CUDA backend doesn't support the format conversion.\n");
+    MNN_ASSERT(false);
+    return;
+}
+
+void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN_DATA_FORMAT dstDataFormat, CUDARuntime* runtime, \
+    const int area, const int batch, const int channel, const Tensor* srcTensor, int precision, bool srcDevice, bool dstDevice) {
+    if(batch == 0 || area == 0 || channel == 0) {
+        MNN_PRINT("Error: formatConvert size batch:%d - plane:%d - channel:%d, format:%d->%d, device:%d->%d\n", batch, area, channel, srcDataFormat, dstDataFormat, srcDevice, dstDevice);
+        return;
+    }
+
+    bool isFp16 = (precision == 2) && (halide_type_float == srcTensor->getType().code);
+    bool isBf16 = (precision == 3) && (halide_type_float == srcTensor->getType().code);
+
+    // int8 case
+    auto des = TensorUtils::getDescribe(srcTensor);
+    if ((des->quantAttr.get() != nullptr && des->type == DataType_DT_INT8) || srcTensor->getType().bits == 8) {
+        insideFormatConvert<int8_t, int8_t>((int8_t *)input, (int8_t *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
+        return;
     }
 
+    // FP case
     if(!srcDevice) {
         if(isFp16) {
-            insideFormatConvert<float, half>((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, half>((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         } else if(isBf16) {
             #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<float, __nv_bfloat16>((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, __nv_bfloat16>((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
             #endif
         } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         }
     } else if(!dstDevice) {
         if(isFp16) {
-            insideFormatConvert<half, float>((half *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<half, float>((half *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         } else if(isBf16) {
             #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<__nv_bfloat16, float>((__nv_bfloat16 *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<__nv_bfloat16, float>((__nv_bfloat16 *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
             #endif
         } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         }
     } else {
         if(isFp16) {
-            insideFormatConvert<half, half>((half *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<half, half>((half *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         } else if(isBf16) {
             #ifdef ENABLE_CUDA_BF16
-            insideFormatConvert<__nv_bfloat16, __nv_bfloat16>((__nv_bfloat16 *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<__nv_bfloat16, __nv_bfloat16>((__nv_bfloat16 *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
             #endif
         } else {
-            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+            insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel, srcDevice, dstDevice);
         }
     }
+    return;
 }
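
Note: the reworked FormatConvert dispatches the floating-point path through a templated insideFormatConvert<T0, T1>, choosing the type pair from the source/destination device residency and the backend precision. A minimal sketch of that selection rule follows; the names here (Elem, pickTypes) are illustrative only and not MNN API.

    // Sketch (not MNN code) of the type-pair selection in the FP path of
    // FormatConvert: precision 2 -> half, precision 3 -> bf16, otherwise float.
    #include <cstdio>
    #include <utility>

    enum class Elem { F32 = 0, F16 = 1, BF16 = 2 };

    // Mirrors the three branches above: a host-side source keeps float input,
    // a host-side destination gets float output, and device-to-device copies
    // stay in the backend's element type.
    static std::pair<Elem, Elem> pickTypes(bool srcDevice, bool dstDevice, Elem devType) {
        if (!srcDevice)              return {Elem::F32, devType};    // upload
        if (srcDevice && !dstDevice) return {devType,   Elem::F32};  // download
        return {devType, devType};                                   // device to device
    }

    int main() {
        auto p = pickTypes(false, true, Elem::F16);
        std::printf("upload: %d -> %d\n", (int)p.first, (int)p.second);  // 0 -> 1 (F32 -> F16)
    }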
 
 

+ 1 - 1
source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu

@@ -26,7 +26,7 @@ ConvCutlassBf16Execution::Resource::Resource(Backend* bn, const MNN::Op* op) {
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &filterDataPtr, &weightSize);
     auto oc = common->outputCount();
 
     int l = weightSize / oc;

+ 1 - 1
source/backend/cuda/execution/int8/ConvInt8CutlassExecution.cu

@@ -185,7 +185,7 @@ ConvInt8CutlassExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
 
-    bool res = ConvolutionCommon::getConvInt8Parameters(conv, quanCommon, filterDataPtr, weightSize, 
+    bool res = ConvolutionCommon::getConvInt8Parameters(conv, quanCommon, bn, filterDataPtr, weightSize, 
                                     mScaleFloatVec, 
                                     mBiasInt32Vec);
                                     // inputScale, 

+ 10 - 10
source/backend/hiai/backend/NPUBackend.cpp

@@ -389,11 +389,11 @@ namespace MNN {
         mSclipMap.clear();
     }
 
-    void NPUBackend::onResizeEnd() {
-        bulidIRModelAndLoad();
+    ErrorCode NPUBackend::onResizeEnd() {
+        return bulidIRModelAndLoad();
     }
 
-    void NPUBackend::bulidIRModelAndLoad() {
+    ErrorCode NPUBackend::bulidIRModelAndLoad() {
         std::vector<ge::Operator> inputs;
         for (auto input : mInputOps){
             inputs.push_back(input.second[0]);
@@ -414,7 +414,7 @@ namespace MNN {
         std::shared_ptr<ge::Model> model = std::make_shared<ge::Model>("model", graphName);
         if (model == nullptr) {
             MNN_ERROR("Create model fail.");
-            return;
+            return INVALID_VALUE;
         }
 
         model->SetGraph(graph);
@@ -431,7 +431,7 @@ namespace MNN {
             std::string buffer(size, ' ');
             if (!file.read(&buffer[0], size)) {
                 MNN_ERROR("Failed to read file.\n");
-                return;
+                return INVALID_VALUE;
             }
             file.close();
             buildOptions.quantizeConfig = buffer;
@@ -440,13 +440,13 @@ namespace MNN {
         auto ret = modelBuilder.Build(buildOptions, modelName, model, builtModel);
         if (ret != hiai::SUCCESS || builtModel == nullptr) {
             MNN_ERROR("model build fail !\n");
-            return;
+            return INVALID_VALUE;
         }
 #ifdef HIAI_DEBUG
         ret = builtModel->SaveToFile("/data/local/tmp/test_quant.om");
         if (ret != hiai::SUCCESS) {
             MNN_ERROR("builtModel SaveToFile failed\n");
-            return;
+            return INVALID_VALUE;
         }
 #endif
         modelManager = hiai::CreateModelManager();
@@ -454,12 +454,12 @@ namespace MNN {
         ret = modelManager->Init(initOptions, builtModel, nullptr);
         if (ret != hiai::SUCCESS) {
             MNN_ERROR("modelManager Init failed");
-            return;
+            return INVALID_VALUE;
         }
         ret = modelManager->SetPriority(hiai::ModelPriority::PRIORITY_HIGH);
         if (ret != hiai::SUCCESS) {
             MNN_ERROR("modelManager SetPriority failed");
-            return;
+            return INVALID_VALUE;
         }
         std::vector<hiai::NDTensorDesc> inputDesc = builtModel->GetInputTensorDescs();
         for (size_t i = 0; i < inputDesc.size(); i++) {
@@ -478,7 +478,7 @@ namespace MNN {
                 index++;
             }
         }
-        return;
+        return NO_ERROR;
     }
 
     int NPUBackend::process() const {
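
Note: together with the Metal, OpenCL and NNAPI hunks below, this commit changes Backend::onResizeEnd from void to ErrorCode so that graph-build failures propagate to the caller instead of being dropped. A minimal sketch of the new override pattern, assuming a simplified Backend interface (MyBackend, buildGraph and the enum values are illustrative, not MNN definitions):

    // Sketch only: the void -> ErrorCode override pattern used in this commit.
    #include <cstdio>

    enum ErrorCode { NO_ERROR, INVALID_VALUE };  // subset; values illustrative

    struct Backend {
        virtual ~Backend() = default;
        virtual ErrorCode onResizeEnd() = 0;     // was: virtual void onResizeEnd()
    };

    struct MyBackend : Backend {
        bool buildGraph() { return true; }       // stands in for bulidIRModelAndLoad()
        ErrorCode onResizeEnd() override {
            // every failure path now returns a code instead of a bare return
            return buildGraph() ? NO_ERROR : INVALID_VALUE;
        }
    };

    int main() {
        MyBackend b;
        std::printf("resize end -> %d\n", b.onResizeEnd());
    }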

+ 3 - 2
source/backend/hiai/backend/NPUBackend.hpp

@@ -18,6 +18,7 @@
 #include <graph/compatible/all_ops.h>
 #include <hiai_ir_build.h>
 #include <graph/buffer.h>
+#include <MNN/ErrorCode.hpp>
 #include <core/Backend.hpp>
 #include <core/Execution.hpp>
 #include "HiAiModelManagerService.h"
@@ -267,11 +268,11 @@ namespace MNN {
         virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
 
         virtual void onResizeBegin() override;
-        virtual void onResizeEnd() override;
+        virtual ErrorCode onResizeEnd() override;
 
     public:
 
-        void bulidIRModelAndLoad();
+        ErrorCode bulidIRModelAndLoad();
         int process() const ;
 
         shared_ptr<ge::Operator> getInputOps(const Op *op, int index = 0);

+ 1 - 1
source/backend/hiai/execution/NPUConvolution.cpp

@@ -36,7 +36,7 @@ ErrorCode NPUConvolution::onResize(const std::vector<Tensor *> &inputs, const st
 
     std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
     if (nullptr != conv2D->quanParameter()) {
-        quanCommon = ConvolutionCommon::load(conv2D->quanParameter(), true);
+        quanCommon = ConvolutionCommon::load(conv2D, backend(), true);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", mOp->name()->c_str());
         }

+ 1 - 1
source/backend/hiai/execution/NPUConvolutionDepthwise.cpp

@@ -36,7 +36,7 @@ ErrorCode NPUConvolutionDepthwise::onResize(const std::vector<Tensor *> &inputs,
 
     std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
     if (nullptr != conv2D->quanParameter()) {
-        quanCommon = ConvolutionCommon::load(conv2D->quanParameter(), true);
+        quanCommon = ConvolutionCommon::load(conv2D, backend(), true);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", mOp->name()->c_str());
         }

+ 2 - 1
source/backend/metal/MetalBackend.hpp

@@ -14,6 +14,7 @@
 #include "core/TensorUtils.hpp"
 #include "MNN_generated.h"
 #include "MetalDefine.h"
+#include <MNN/ErrorCode.hpp>
 #include <vector>
 //#include "MNNMetalContext.h"
 #include "MetalCache_generated.h"
@@ -141,7 +142,7 @@ public:
                                 const MNN::Op *op) override;
     
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;
     virtual int onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) override;

+ 2 - 1
source/backend/metal/MetalBackend.mm

@@ -355,9 +355,10 @@ void MetalBackend::onResizeBegin() {
     [ctx wait];
 }
 
-void MetalBackend::onResizeEnd() {
+ErrorCode MetalBackend::onResizeEnd() {
     auto ctx = (__bridge MNNMetalContext *)context();
     mFrameEncodeCache = (!ctx.isCommitEachShader && mOpFullSupport);
+    return NO_ERROR;
 }
 
 void MetalBackend::onCopyHostToDevice(const Tensor *src, const Tensor *dst) const {

+ 1 - 1
source/backend/metal/MetalConvolutionCommon.mm

@@ -94,7 +94,7 @@ static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int oc,
 void MetalConvolutionCommon::loadWeight(const MNN::Convolution2D *conv) {
     std::shared_ptr<ConvolutionCommon::Int8Common> qnt = NULL;
     if (conv->quanParameter()) {
-        qnt          = ConvolutionCommon::load(conv->quanParameter(), true);
+        qnt          = ConvolutionCommon::load(conv, backend(), true);
     }
     mWeight = weightForConv(conv, qnt.get(), mDepthwise);
 }

+ 1 - 1
source/backend/metal/MetalDeconvolution.mm

@@ -139,7 +139,7 @@ MetalDeconvolution::MetalDeconvolution(Backend *backend, const MNN::Op *op) : Ex
     // forcy downgrade to float like what CPU does
     std::shared_ptr<ConvolutionCommon::Int8Common> qnt = NULL;
     if (deconv->quanParameter()) {
-        qnt = ConvolutionCommon::load(deconv->quanParameter(), true);
+        qnt = ConvolutionCommon::load(deconv, backend, true);
     }
     mWeight = weightForDeconv(context, mDepthwise, deconv, qnt.get());
     mBias   = biasForDeconv(context, deconv);

+ 3 - 2
source/backend/nnapi/backend/NNAPIBackend.cpp

@@ -252,7 +252,7 @@ namespace MNN {
         mQuantCacheMap.clear();
     }
 
-    void NNAPIBackend::onResizeEnd() {
+    ErrorCode NNAPIBackend::onResizeEnd() {
         buildModel();
         mHalfBuffer.clear();
         mQuantCacheMap.clear();
@@ -453,7 +453,7 @@ namespace MNN {
         return NO_ERROR;
     }
 
-    void NNAPIBackend::buildModel() {
+    ErrorCode NNAPIBackend::buildModel() {
         // set input and output of model
         std::vector<uint32_t> inputOperands(mInputTensors.size()), outputOperands(mOutputTensors.size());
         for (int i = 0; i < mInputTensors.size(); i++) {
@@ -503,6 +503,7 @@ namespace MNN {
         CHECK(ANeuralNetworksCompilation_setPreference_27, mNNAPICompilation, ANEURALNETWORKS_PREFER_SUSTAINED_SPEED);
         CHECK(ANeuralNetworksCompilation_finish_27, mNNAPICompilation);
         CHECK(ANeuralNetworksBurst_create_29, mNNAPICompilation, &mNNAPIBurst);
+        return NO_ERROR;
     }
 
     void NNAPIBackend::invokeModel() const {

+ 3 - 2
source/backend/nnapi/backend/NNAPIBackend.hpp

@@ -12,6 +12,7 @@
 #include <stdio.h>
 #include <map>
 #include <memory>
+#include <MNN/ErrorCode.hpp>
 #include <core/Backend.hpp>
 #include <core/Execution.hpp>
 #include <core/TensorUtils.hpp>
@@ -80,7 +81,7 @@ namespace MNN {
         virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
 
         virtual void onResizeBegin() override;
-        virtual void onResizeEnd() override;
+        virtual ErrorCode onResizeEnd() override;
 
     public:
         class Creator {
@@ -118,7 +119,7 @@ namespace MNN {
         ErrorCode buildQuantOperation(const Tensor* src, const Tensor* dst);
         ErrorCode replaceTensorWith(const Tensor* src, const Tensor* replace);
         uint32_t buildDequantOperand(const Tensor* t);
-        void buildModel();
+        ErrorCode buildModel();
         void invokeModel() const;
     private:
         bool mNCHW = false;

+ 1 - 1
source/backend/nnapi/execution/NNAPIConvolution.cpp

@@ -81,7 +81,7 @@ ErrorCode NNAPIConvolution::onResize(const std::vector<Tensor *> &inputs, const
             weightPtr = conv2D->quanParameter()->buffer()->data();
             weightSize = conv2D->quanParameter()->buffer()->size();
         } else if (nullptr != conv2D->quanParameter()) {
-            quanCommon = ConvolutionCommon::load(conv2D->quanParameter(), true);
+            quanCommon = ConvolutionCommon::load(conv2D, backend(), true);
             if (nullptr == quanCommon) {
                 MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", mOp->name()->c_str());
             }

+ 2 - 1
source/backend/opencl/core/OpenCLBackend.cpp

@@ -505,11 +505,12 @@ void OpenCLBackend::onResizeBegin() {
     mOpenCLRuntime->releaseRecord();
 }
 
-void OpenCLBackend::onResizeEnd() {
+ErrorCode OpenCLBackend::onResizeEnd() {
 #ifndef ENABLE_OPENCL_TIME_PROFILER
     mOpenCLRuntime->setCommandQueueProfileDisable();
 #endif
     mOpenCLRuntime->endRecord();
+    return NO_ERROR;
 }
 
 void OpenCLBackend::onExecuteBegin() const {

+ 2 - 1
source/backend/opencl/core/OpenCLBackend.hpp

@@ -11,6 +11,7 @@
 
 #include "core/Backend.hpp"
 #include "MNN_generated.h"
+#include <MNN/ErrorCode.hpp>
 
 #include <list>
 #include <vector>
@@ -102,7 +103,7 @@ public:
                                 const MNN::Op *op) override;
 
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
 
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;

+ 1 - 1
source/backend/opencl/execution/buffer/BinaryBufExecution.cpp

@@ -315,7 +315,7 @@ public:
                 case BinaryOpOperation_SquaredDifference:
                     return new BinaryBufExecution(inputs, "(in0-in1)*(in0-in1)", op, backend);
                 case BinaryOpOperation_ATAN2:
-                    return new BinaryBufExecution(inputs, "atan(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    return new BinaryBufExecution(inputs, "(in1==(FLOAT4)0?(sign(in0)*(FLOAT4)(PI/2)):(atan(in0/in1)+(in1>(FLOAT4)0?(FLOAT4)0:sign(in0)*(FLOAT4)PI)))", op, backend);
                 case BinaryOpOperation_NOTEQUAL:
                     return new BinaryBufExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend);
                 case BinaryOpOperation_MOD:
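
Note: the previous ATAN2 expression clamped the divisor, which loses the sign needed to pick the quadrant and mishandles in1 == 0. The replacement is the usual piecewise atan2: sign(in0)*PI/2 when in1 == 0, atan(in0/in1) when in1 > 0, and atan(in0/in1) + sign(in0)*PI when in1 < 0. A scalar C++ restatement of the same formula (sketch; the kernel applies it per FLOAT4 lane):

    #include <cmath>
    #include <cstdio>

    static const float PI = 3.141592653589f;  // same constant the .cl files now define

    // Scalar version of the new OpenCL ATAN2 expression.
    static float atan2LikeKernel(float in0, float in1) {
        float sign0 = (in0 > 0.f) - (in0 < 0.f);
        if (in1 == 0.f) {
            return sign0 * PI / 2.f;
        }
        float base = std::atan(in0 / in1);
        return in1 > 0.f ? base : base + sign0 * PI;
    }

    int main() {
        // both print ~2.356 (3*pi/4)
        std::printf("%f vs %f\n", atan2LikeKernel(1.f, -1.f), std::atan2(1.f, -1.f));
    }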

+ 2 - 2
source/backend/opencl/execution/buffer/ConvBufExecution.cpp

@@ -325,7 +325,7 @@ ConvBufExecution::ConvBufExecution(const std::vector<Tensor *> &inputs, const st
         mRasterExe.reset(new RasterBufExecution({mFilter.get()}, op, mOpenCLBackend));
     } else {
         int weightSize   = 0;
-        ConvolutionCommon::getConvParameters(&quanCommon, conv2dParams, &mFilterDataPtr, &weightSize);
+        ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2dParams, &mFilterDataPtr, &weightSize);
         //select opt conv method
         mConv1x1Opt = (mKernelHeight == mKernelWidth && mKernelHeight == 1 && mPaddings[0] == 0 &&
         mPaddings[1] == 0 && mStrides[0] == 1 && mStrides[1] == 1 && inputs[0]->width() >= 4);
@@ -517,7 +517,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
         int min_index  = min_cost.second;
         if(min_index >= c8_index_start) {//if best kernel is "conv_2d_1x1_c8h1w4", set weight packCout to 8
             int weightSize   = 0;
-            ConvolutionCommon::getConvParameters(&quanCommon, mConv2dParams, &mFilterDataPtr, &weightSize);
+            ConvolutionCommon::getConvParameters(&quanCommon, backend(), mConv2dParams, &mFilterDataPtr, &weightSize);
             setConv1x1WeightBuffer(8, 4, mFilterDataPtr);
         }
         mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};

+ 1 - 1
source/backend/opencl/execution/buffer/ConvBufWinograd.cpp

@@ -54,7 +54,7 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend)
     int weightSize             = 0;
     const float* filterDataPtr = nullptr;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, op, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, op, &filterDataPtr, &weightSize);
 
     int oc     = mCommon->outputCount();
     int ic     = weightSize / oc / mCommon->kernelX() / mCommon->kernelY();

+ 1 - 1
source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp

@@ -112,7 +112,7 @@ ConvSubgroupBuf::ConvSubgroupBuf(const std::vector<Tensor *> &inputs, const std:
         const float *FilterDataPtr = NULL;
         int weightSize = 0;
         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-        ConvolutionCommon::getConvParameters(&quanCommon, conv2dParams, &FilterDataPtr, &weightSize);
+        ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2dParams, &FilterDataPtr, &weightSize);
         if (FilterDataPtr != nullptr) {
             std::shared_ptr<Tensor> sourceWeight(
                 Tensor::create<float>(std::vector<int>{mOutputChannel, mInputChannel, mKernelWidth, mKernelHeight},

+ 1 - 1
source/backend/opencl/execution/buffer/DeconvBufExecution.cpp

@@ -35,7 +35,7 @@ DeconvBufExecution::DeconvBufExecution(const std::vector<Tensor *> &inputs, cons
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2dParams, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2dParams, &filterDataPtr, &weightSize);
 
     int inputChannel  = weightSize / (kernelWidth * kernelHeight * outputChannel);
     std::vector<int> filterShape{outputChannel, inputChannel, kernelHeight, kernelWidth};

+ 1 - 1
source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp

@@ -35,7 +35,7 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector<Tensor *>
     const float* filterDataPtr = nullptr;
     int filterDataSize   = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, mCon2dParams, &filterDataPtr, &filterDataSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, mCon2dParams, &filterDataPtr, &filterDataSize);
 
     mFilter.reset(Tensor::createDevice<float>({1, ROUND_UP(filterImageShape[1], 2)/*for kernel C8 read*/, 1, 4 * filterImageShape[0]}));
     std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>(filterShape));

+ 1 - 1
source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp

@@ -40,7 +40,7 @@ DepthwiseConvSubgroupBufExecution::DepthwiseConvSubgroupBufExecution(const std::
         const float *filterDataPtr = nullptr;
         int filterDataSize         = 0;
         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-        ConvolutionCommon::getConvParameters(&quanCommon, mCon2dParams, &filterDataPtr, &filterDataSize);
+        ConvolutionCommon::getConvParameters(&quanCommon, backend, mCon2dParams, &filterDataPtr, &filterDataSize);
         if (filterDataPtr != nullptr) {
             std::shared_ptr<Tensor> sourceWeight(Tensor::create<float>(
                 std::vector<int>{1, outputChannel, kernelWidth, kernelHeight},

+ 121 - 11
source/backend/opencl/execution/buffer/LoopBufExecution.cpp

@@ -17,7 +17,10 @@ namespace OpenCL {
 
 static void _TileOrPackTensor(Tensor *input, Tensor *output, cl::Kernel& kernel, cl::NDRange &globalWorkSize,
                         cl::NDRange &localWorkSize, const int Width, const int Height, const int Channel,
-                        const int Batch, OpenCLRuntime *runTime, const std::string &KernelName, const std::set<std::string> &buildOptions) {
+                        const int Batch, OpenCLRuntime *runTime, const std::string &KernelName, std::set<std::string> buildOptions) {
+    if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC || TensorUtils::getDescribe(input)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){
+        buildOptions.emplace("-DMNN_NHWC");
+    }
     kernel = runTime->buildKernel("loop_buf", KernelName, buildOptions);
     uint32_t mMaxWorkGroupSize  = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(kernel));
     std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(Width * Height), (uint32_t)(UP_DIV(Channel, 4)), (uint32_t)(Batch)};
@@ -92,7 +95,7 @@ static void _setTensorStack(std::vector<Tensor *> &result, const std::vector<Ten
         const int Width = Shape.at(2);
         const int Height = Shape.at(1);
         const int Batch = Shape.at(0);
-        mTmpTensors[1] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}));
+        mTmpTensors[1] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE));
         mOpenCLBackend->onAcquireBuffer(mTmpTensors[1].get(), Backend::DYNAMIC);
 
         Unit unit;
@@ -108,7 +111,7 @@ static void _setTensorStack(std::vector<Tensor *> &result, const std::vector<Ten
             const int Width = Shape.at(2);
             const int Height = Shape.at(1);
             const int Batch = Shape.at(0);
-            mOffsetTensors.emplace_back(std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width})));
+            mOffsetTensors.emplace_back(std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE)));
             mOpenCLBackend->onAcquireBuffer(mOffsetTensors.back().get(), Backend::DYNAMIC);
 
             Unit unit;
@@ -119,13 +122,13 @@ static void _setTensorStack(std::vector<Tensor *> &result, const std::vector<Ten
      
      // gather
      {
-        mTmpTensors[0] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{n, z, y, x}));
+        mTmpTensors[0] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{n, z, y, x}, Tensor::CAFFE));
         mOpenCLBackend->onAcquireBuffer(mTmpTensors[0].get(), Backend::DYNAMIC);
         int offset_index = 0;
 
         Unit unit;
-        std::string KernelName = "batch_gather_buf";
-        unit.kernel = runTime->buildKernel("loop_buf", KernelName, mBuildOptions);
+        std::string KernelName = "batch_gather";
+        unit.kernel = runTime->buildKernel("loop", KernelName, mBuildOptions);
         uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(unit.kernel));
         std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(n)};
 
@@ -222,7 +225,7 @@ ErrorCode LoopBatchMatMulBufExecution::onResize(const std::vector<Tensor *> &inp
         const int Width = Shape.at(2);
         const int Height = Shape.at(1);
         const int Batch = Shape.at(0);
-        mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}));
+        mTmpTensors[i] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE));
         mOpenCLBackend->onAcquireBuffer(mTmpTensors[i].get(), Backend::DYNAMIC);       
 
         Unit unit;
@@ -238,7 +241,7 @@ ErrorCode LoopBatchMatMulBufExecution::onResize(const std::vector<Tensor *> &inp
             const int Width = Shape.at(2);
             const int Height = Shape.at(1);
             const int Batch = Shape.at(0);
-            mOffsetTensors.emplace_back(std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width})));
+            mOffsetTensors.emplace_back(std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{Batch, Channel, Height, Width}, Tensor::CAFFE)));
             mOpenCLBackend->onAcquireBuffer(mOffsetTensors.back().get(), Backend::DYNAMIC);
 
             Unit unit;
@@ -249,12 +252,12 @@ ErrorCode LoopBatchMatMulBufExecution::onResize(const std::vector<Tensor *> &inp
 
      // matmul
      {
-        mTmpTensors[0] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{1, n, e, h}));
+        mTmpTensors[0] = std::make_shared<Tensor>(Tensor::createDevice<float>(std::vector<int>{1, n, e, h}, Tensor::CAFFE));
         mOpenCLBackend->onAcquireBuffer(mTmpTensors[0].get(), Backend::DYNAMIC);
         int offset_index = 0;
 
         Unit unit;
-        std::string KernelName = "batch_matmul_buf";
+        std::string KernelName = "batch_matmul";
         if (mHasBias) {
             mBuildOptions.emplace("-DBIAS");
         }
@@ -264,7 +267,7 @@ ErrorCode LoopBatchMatMulBufExecution::onResize(const std::vector<Tensor *> &inp
         if (mTransposeB) {
             mBuildOptions.emplace("-DTRANSPOSE_B");
         }
-        unit.kernel = runTime->buildKernel("loop_buf", KernelName, mBuildOptions);
+        unit.kernel = runTime->buildKernel("loop", KernelName, mBuildOptions);
         uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(unit.kernel));
         std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(h), (uint32_t)(e),(uint32_t)(n)};
 
@@ -324,6 +327,70 @@ ErrorCode LoopBatchMatMulBufExecution::onResize(const std::vector<Tensor *> &inp
     return NO_ERROR;
 }
 
+LoopBinaryBufExecution::LoopBinaryBufExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn)
+    : CommonExecution(bn, op) {
+    mLoop = loop;
+    mTensors.resize(mLoop->tensorNumber());
+    auto cmd = loop->commands()->GetAs<RegionCommand>(0);
+    mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute);
+}
+
+ErrorCode LoopBinaryBufExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    auto cmd                      = mLoop->commands()->GetAs<RegionCommand>(0);
+    OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend();
+    auto runTime                  = mOpenCLBackend->getOpenCLRuntime();
+    _setTensorStack(mTensors, inputs, outputs, mLoop);
+    mUnits.clear();
+    
+    Unit unit;
+    auto input0 = mTensors[cmd->indexes()->data()[1]];
+    std::vector<int> Input0Shape = tensorShapeFormat(input0);
+    int Input0Size[4] = {Input0Shape.at(2), Input0Shape.at(1),Input0Shape.at(3),Input0Shape.at(0)};
+         
+    auto input1 = mTensors[cmd->indexes()->data()[2]];
+    std::vector<int> Input1Shape = tensorShapeFormat(input1);
+    int Input1Size[4] = {Input1Shape.at(2), Input1Shape.at(1),Input1Shape.at(3),Input1Shape.at(0)};
+         
+    auto output = mTensors[cmd->indexes()->data()[0]];
+    std::vector<int> Shape = tensorShapeFormat(output);
+    const int Channel = Shape.at(3);
+    const int Width = Shape.at(2);
+    const int Height = Shape.at(1);
+    const int Batch = Shape.at(0);
+    const int ChannelBlock = UP_DIV(Channel, 4);
+    auto BuildOptions = mBuildOptions;
+    if(Input0Size[2] != Input1Size[2]){
+        BuildOptions.emplace("-DBROADCAST_CHANNEL");
+    }
+    std::string KernelName = "broadcast_binary_buf";
+    unit.kernel = runTime->buildKernel("loop_buf", KernelName, BuildOptions);
+    uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(unit.kernel));
+
+    std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(Width), (uint32_t)(Height), (uint32_t)(Batch * ChannelBlock)};
+
+    uint32_t index = 0;
+    cl_int ret = CL_SUCCESS;
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[0]);
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[1]);
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[2]);
+    ret |= unit.kernel.setArg(index++, openCLBuffer(output));
+    ret |= unit.kernel.setArg(index++, openCLBuffer(input0));
+    ret |= unit.kernel.setArg(index++, openCLBuffer(input1));
+    ret |= unit.kernel.setArg(index++, sizeof(Input0Size), Input0Size);
+    ret |= unit.kernel.setArg(index++, sizeof(Input1Size), Input1Size);
+    ret |= unit.kernel.setArg(index++, Width);
+    ret |= unit.kernel.setArg(index++, Height);
+    ret |= unit.kernel.setArg(index++, ChannelBlock);
+    MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryBufExecution");
+
+    std::vector<uint32_t> mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel).first;
+
+    unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]};
+    unit.localWorkSize  = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]};
+    mUnits.emplace_back(unit);
+    return NO_ERROR;
+}
+
 class LoopBufCreator : public OpenCLBackend::Creator {
 public:
     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
@@ -351,6 +418,49 @@ public:
             if (OpType_MatMul == subop->type() && loop->parallel()) {
                 return new LoopBatchMatMulBufExecution(loop, op, backend);
             }
+            if (OpType_BinaryOp == subop->type() && loop->parallel()) {
+                switch (subop->main_as_BinaryOp()->opType()) {
+                    case BinaryOpOperation_MUL:
+                        return new LoopBinaryBufExecution(loop, "in0*in1", op, backend);
+                    case BinaryOpOperation_ADD:
+                        return new LoopBinaryBufExecution(loop, "in0+in1", op, backend);
+                    case BinaryOpOperation_SUB:
+                        return new LoopBinaryBufExecution(loop, "in0-in1", op, backend);
+                    case BinaryOpOperation_REALDIV:
+                        return new LoopBinaryBufExecution(loop, "sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001))", op, backend);
+                    case BinaryOpOperation_MINIMUM:
+                        return new LoopBinaryBufExecution(loop, "in0>in1?in1:in0", op, backend);
+                    case BinaryOpOperation_MAXIMUM:
+                        return new LoopBinaryBufExecution(loop, "in0>in1?in0:in1", op, backend);
+                    case BinaryOpOperation_GREATER:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-isgreater(in0,in1))", op, backend);
+                    case BinaryOpOperation_LESS:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-isless(in0,in1))", op, backend);
+                    case BinaryOpOperation_LESS_EQUAL:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-islessequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_GREATER_EQUAL:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-isgreaterequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_EQUAL:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-isequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_FLOORDIV:
+                        return new LoopBinaryBufExecution(loop, "floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    case BinaryOpOperation_FLOORMOD:
+                        return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend);
+                    case BinaryOpOperation_POW:
+                        return new LoopBinaryBufExecution(loop, "pow(in0,in1)", op, backend);
+                    case BinaryOpOperation_SquaredDifference:
+                        return new LoopBinaryBufExecution(loop, "(in0-in1)*(in0-in1)", op, backend);
+                    case BinaryOpOperation_ATAN2:
+                        return new LoopBinaryBufExecution(loop, "atan(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    case BinaryOpOperation_NOTEQUAL:
+                        return new LoopBinaryBufExecution(loop, "convert_float4(-isnotequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_MOD:
+                        return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend);
+                    default:
+                        break;
+                }
+                return nullptr;
+            }
         }
         return nullptr;
     }
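
Note: LoopBinaryBufExecution compiles one generic broadcast_binary_buf kernel and specializes it at build time by injecting the per-op expression through -DLOOP_BINARY_OPERATOR, so every BinaryOp variant in the creator above reuses the same kernel body. A small sketch of that pattern, using hypothetical helper names (exprFor and buildDefine are not MNN APIs):

    #include <cstdio>
    #include <string>

    // Map a binary op to the expression substituted for LOOP_BINARY_OPERATOR.
    static std::string exprFor(const std::string& opName) {
        if (opName == "ADD") return "in0+in1";
        if (opName == "MUL") return "in0*in1";
        if (opName == "POW") return "pow(in0,in1)";
        return "in0+in1";  // fallback for the sketch only
    }

    static std::string buildDefine(const std::string& opName) {
        return "-DLOOP_BINARY_OPERATOR=" + exprFor(opName);  // passed as a buildKernel option
    }

    int main() {
        std::printf("%s\n", buildDefine("MUL").c_str());  // -DLOOP_BINARY_OPERATOR=in0*in1
    }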

+ 13 - 0
source/backend/opencl/execution/buffer/LoopBufExecution.hpp

@@ -54,6 +54,19 @@ private:
     std::set<std::string> mBuildOptions;
 };
 
+
+class LoopBinaryBufExecution : public CommonExecution {
+public:
+    LoopBinaryBufExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn);
+    virtual ~LoopBinaryBufExecution() = default;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+
+private:
+    const LoopParam *mLoop;
+    std::vector<Tensor *> mTensors;
+    std::set<std::string> mBuildOptions;
+};
+
 } // namespace OpenCL
 } // namespace MNN
 #endif /* LoopBufExecution_hpp */

+ 39 - 46
source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp

@@ -21,10 +21,11 @@ SoftmaxBufExecution::SoftmaxBufExecution(const std::vector<Tensor *> &inputs, in
     mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
 }
 
-bool SoftmaxBufExecution::buildSoftmaxKernel() {
+bool SoftmaxBufExecution::buildSoftmaxKernel(int localSize) {
     auto runtime = mOpenCLBackend->getOpenCLRuntime();
     if (mKernel.get() == nullptr) {
         std::set<std::string> buildOptions;
+        buildOptions.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
         std::string kernelName;
         if (mAxis == 1) {
             mKernel           = runtime->buildKernel("softmax_buf", "softmax_channel", buildOptions);
@@ -39,6 +40,14 @@ bool SoftmaxBufExecution::buildSoftmaxKernel() {
     return true;
 }
 
+int SoftmaxBufExecution::getLocalSize(int size, int maxGroupSize){
+    int local_size = 1;
+    while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){
+        local_size *= 2;
+    }
+    return local_size;
+}
+
 ErrorCode SoftmaxBufExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     Tensor *input  = inputs[0];
     Tensor *output = outputs[0];
@@ -70,63 +79,47 @@ ErrorCode SoftmaxBufExecution::onResize(const std::vector<Tensor *> &inputs, con
 
     const int channelBlocks  = UP_DIV(outputChannels, 4);
     const int remainChannels = channelBlocks * 4 - outputChannels;
+    auto MaxWorkItems = mOpenCLBackend->getOpenCLRuntime()->getMaxWorkItemSizes();
+    int localSize = getLocalSize(channel, MaxWorkItems[0]);
+    if(localSize < 4){
+        localSize = 1;
+    }
     if(inputBatch == outside && channel == inputChannels && inside == inputWidth * inputHeight){
         mAxis = 1;
-    }else if(inputBatch * inputChannels == outside && channel == inputHeight && inside == inputHeight){
+        localSize = getLocalSize(channelBlocks, MaxWorkItems[0]);
+    }else if(inputBatch * inputChannels == outside && channel == inputHeight && inside == inputWidth){
         mAxis = 2;
     }else if(inputBatch * inputChannels * inputHeight == outside && channel == inputWidth && inside == 1){
         mAxis = 3;
     }
-    buildSoftmaxKernel();
+    buildSoftmaxKernel(localSize);
     
+    cl_int ret = CL_SUCCESS;
+    int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
     if (mAxis == 1) {
-        mGlobalWorkSize = {static_cast<uint32_t>(outputWidth),
-            static_cast<uint32_t>(outputHeight * outputBatch), 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-
-        uint32_t idx    = 0;
-        cl_int ret = CL_SUCCESS;
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[2]);
-
-        ret |= mKernel.setArg(idx++, openCLBuffer(input));
-        ret |= mKernel.setArg(idx++, openCLBuffer(output));
-        ret |= mKernel.setArg(idx++, static_cast<int>(outputChannels));
-        ret |= mKernel.setArg(idx++, remainChannels);
-        ret |= mKernel.setArg(idx++, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution axis_1");
-
-        std::string kernelName = "softmax_buf_channel";
-        mLocalWorkSize =
-        localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)outputWidth, (uint32_t)outputHeight * outputBatch};
+
     } else if (mAxis == 2){
-        mGlobalWorkSize = {(uint32_t)channelBlocks*outputWidth, (uint32_t)outputBatch, 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-        cl_int ret = CL_SUCCESS;
-        ret |= mKernel.setArg(0, openCLBuffer(input));
-        ret |= mKernel.setArg(1, openCLBuffer(output));
-        ret |= mKernel.setArg(2, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution axis_2");
-
-        std::string kernelName = "softmax_buf_height";
-        mLocalWorkSize =
-        localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks*outputWidth, (uint32_t)outputBatch};
     } else {
         MNN_ASSERT(mAxis == 3);
-        mGlobalWorkSize = {(uint32_t)channelBlocks, (uint32_t)outputBatch*outputHeight, 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-        cl_int ret = CL_SUCCESS;
-        ret |= mKernel.setArg(0, openCLBuffer(input));
-        ret |= mKernel.setArg(1, openCLBuffer(output));
-        ret |= mKernel.setArg(2, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution axis_3");
-
-        std::string kernelName = "softmax_buf_width";
-        mLocalWorkSize =
-        localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks, (uint32_t)outputBatch*outputHeight};
+    }
+    mLocalWorkSize = {(uint32_t)(localSize), 1, 1};
+    
+    uint32_t idx    = 0;
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[2]);
+
+    ret |= mKernel.setArg(idx++, openCLImage(input));
+    ret |= mKernel.setArg(idx++, openCLImage(output));
+    ret |= mKernel.setArg(idx++, remainChannels);
+    ret |= mKernel.setArg(idx++, shape);
+    MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution");
+    if(localSize == 1){
+        mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "softmax_buf", mKernel).first;
     }
-
     return NO_ERROR;
 }
 

+ 2 - 1
source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp

@@ -26,8 +26,9 @@ public:
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
-    bool buildSoftmaxKernel();
+    bool buildSoftmaxKernel(int localSize);
 private:
+    int getLocalSize(int size, int maxGroupSize);
     cl::Kernel mKernel;
     uint32_t mMaxWorkGroupSize;
     OpenCLBackend *mOpenCLBackend;

+ 1 - 0
source/backend/opencl/execution/cl/binary.cl

@@ -1,6 +1,7 @@
 #ifdef MNN_SUPPORT_FP16
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #endif
+#define PI 3.141592653589f
 __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 
 __kernel void binary(__private int global_dim0, __private int global_dim1,

+ 1 - 0
source/backend/opencl/execution/cl/binary_buf.cl

@@ -1,6 +1,7 @@
 #ifdef MNN_SUPPORT_FP16
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #endif
+#define PI 3.141592653589f
 
 __kernel void binary_buf(__private int global_dim0, __private int global_dim1,
                          __global FLOAT* input0, __global FLOAT* input1, __global FLOAT* output,

+ 64 - 1
source/backend/opencl/execution/cl/loop.cl

@@ -89,10 +89,17 @@ __kernel void tile(__private int global_dim0, __private int global_dim1, __priva
         const int h = pos.x / width;
         const int c = pos.y << 2;
 
+#ifdef MNN_NHWC
+        const int c_dst_pitch = 1;
+        const int x_dst_pitch = c_dst_pitch * channel;
+        const int y_dst_pitch = x_dst_pitch * width;
+        const int b_dst_pitch = y_dst_pitch * height;
+#else
         const int x_dst_pitch = 1;
         const int y_dst_pitch = x_dst_pitch * width;
         const int c_dst_pitch = y_dst_pitch * height;
         const int b_dst_pitch = c_dst_pitch * channel;
+#endif
         __global FLOAT* dst_ptr = output + pos.z * b_dst_pitch + c * c_dst_pitch + h * y_dst_pitch + w * x_dst_pitch;
         
         FLOAT4 value = RI_F(input, SAMPLER, (int2)(pos.y * width + w, pos.z * height + h));
@@ -118,10 +125,17 @@ __kernel void pack(__private int global_dim0, __private int global_dim1, __priva
         const int h = pos.x / width;
         const int c = pos.y << 2;
 
+#ifdef MNN_NHWC
+        const int c_src_pitch = 1;
+        const int x_src_pitch = c_src_pitch * channel;
+        const int y_src_pitch = x_src_pitch * width;
+        const int b_src_pitch = y_src_pitch * height;
+#else
         const int x_src_pitch = 1;
         const int y_src_pitch = x_src_pitch * width;
         const int c_src_pitch = y_src_pitch * height;
         const int b_src_pitch = c_src_pitch * channel;
+#endif
         __global FLOAT* src_ptr = input + pos.z * b_src_pitch + c * c_src_pitch + h * y_src_pitch + w * x_src_pitch;
         FLOAT4 value = (FLOAT4)0;
         FLOAT *value_ptr = (FLOAT*)&value;
@@ -157,4 +171,53 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1,
         int2 offset = index * steps;
         output[offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z] = input[offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z];
     }
-}
+}
+
+#ifdef LOOP_BINARY_OPERATOR
+__kernel void broadcast_binary(__private int global_dim0, __private int global_dim1, __private int global_dim2,
+                         __write_only image2d_t output, __read_only image2d_t input0, __read_only image2d_t input1,
+                         __private const int4 src0_size, //(width, height, channel, batch)
+                         __private const int4 src1_size,
+                         __private const int dst_width, __private const int dst_height,
+                         __private const int channel_block) {
+    int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));
+    
+    if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) {
+        
+        const int w = pos.x;
+        const int h = pos.y;
+        const int c = pos.z % channel_block;
+        const int n = pos.z / channel_block;
+        
+        FLOAT4 in0 = RI_F(input0, SAMPLER, (int2)(c * src0_size.x + w, n * src0_size.y + h));
+#ifdef BROADCAST_CHANNEL
+        const int w1 = w % src1_size.x;
+        const int h1 = h % src1_size.y;
+        const int n1 = n % src1_size.w;
+        const int c1 = c << 2;
+        int4 c1_vec = (int4)(c1, c1 + 1, c1 + 2, c1 + 3);
+        c1_vec = c1_vec % (int4)(src1_size.z);
+        int4 c4_vec = (c1_vec + 3) / 4;
+        FLOAT4 in1;
+        FLOAT* in1_ptr = (FLOAT*)&in1;
+        int* c1_vec_prt = (int*)&c1_vec;
+        int* c4_vec_prt = (int*)&c4_vec;
+        for(int i = 0; i < 4; ++i){
+            int remain = (c4_vec_prt[i] << 2) - c1_vec_prt[i];
+            FLOAT4 tmp = RI_F(input1, SAMPLER, (int2)(c4_vec_prt[i] * src1_size.x + w1, n1 * src1_size.y + h1));
+            FLOAT* tmp_ptr = (FLOAT*)&tmp;
+            in1_ptr[i] = tmp_ptr[remain];
+        }
+#else
+        const int w1 = w % src1_size.x;
+        const int h1 = h % src1_size.y;
+        const int c1 = c;
+        const int n1 = n % src1_size.w;
+        FLOAT4 in1 = RI_F(input1, SAMPLER, (int2)(c1 * src1_size.x + w1, n1 * src1_size.y + h1));
+#endif
+        FLOAT4 out = CONVERT_FLOAT4(LOOP_BINARY_OPERATOR);
+        WI_F(output, (int2)(c * dst_width + w, n * dst_height + h), out);
+    }
+}
+#endif
+

+ 56 - 95
source/backend/opencl/execution/cl/loop_buf.cl

@@ -1,80 +1,6 @@
 #ifdef MNN_SUPPORT_FP16
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #endif
-
-__kernel void batch_matmul_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
-                         __global FLOAT* output, __global FLOAT* input_A, __global FLOAT* input_B,
-#ifdef BIAS
-                        __global FLOAT* input_C,
-#endif
-                        __global FLOAT* offset_O, __global FLOAT* offset_A, __global FLOAT* offset_B,
-#ifdef BIAS
-                        __global FLOAT* offset_C,
-#endif
-                         __private const int e,
-                         __private const int l,
-                         __private const int h,
-                         __private const int4 offsets,
-                         __private const int4 iters,
-                         __private const int4 steps) {
-    int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));
-    
-    if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) {
-        int4 index = (int4)(pos.z);
-        if (iters.x >= 0) {
-            index.x = (int)(offset_O[pos.z]);
-        }
-        if (iters.y >= 0) {
-            index.y = (int)(offset_A[pos.z]);
-        }
-        if (iters.z >= 0) {
-            index.z = (int)(offset_B[pos.z]);
-        }
-#ifdef BIAS
-        if (iters.w >= 0) {
-            index.w = (int)(offset_C[pos.z]);
-        }
-#endif
-        int4 offset = index * steps + offsets;
-        
-#if TRANSPOSE_A
-        __global FLOAT* A_ptr = input_A + offset.y + pos.y;
-#else
-        __global FLOAT* A_ptr = input_A + offset.y + pos.y * l;
-#endif
-
-#if TRANSPOSE_B
-        __global FLOAT* B_ptr = input_B + offset.z + pos.x * l;
-#else
-        __global FLOAT* B_ptr = input_B + offset.z + pos.x;
-#endif
-
-#ifdef BIAS
-        FLOAT value = input_C[offset.w + pos.x];
-#else
-        FLOAT value = 0;
-#endif
-
-        for(int i = 0; i < l; ++i){
-#if TRANSPOSE_A
-            FLOAT value_a = A_ptr[i * e];
-#else
-            FLOAT value_a = A_ptr[i];
-#endif
-
-#if TRANSPOSE_B
-            FLOAT value_b = B_ptr[i];
-#else
-            FLOAT value_b = B_ptr[i * h];
-#endif
-
-            value = mad(value_a, value_b, value);
-        }
-
-        output[offset.x + pos.y * h + pos.x] = value;
-    }
-}
-
 __kernel void tile_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
                         __global FLOAT* input, __global FLOAT* output,
                         __private const int width,
@@ -89,11 +15,17 @@ __kernel void tile_buf(__private int global_dim0, __private int global_dim1, __p
         const int y_src_pitch = x_src_pitch * width;
         const int c_src_pitch = y_src_pitch * height;
         const int b_src_pitch = c_src_pitch * ((channel + 3) / 4);
-
+#ifdef MNN_NHWC
+        const int c_dst_pitch = 1;
+        const int x_dst_pitch = c_dst_pitch * channel;
+        const int y_dst_pitch = x_dst_pitch * width;
+        const int b_dst_pitch = y_dst_pitch * height;
+#else
         const int x_dst_pitch = 1;
         const int y_dst_pitch = x_dst_pitch * width;
         const int c_dst_pitch = y_dst_pitch * height;
         const int b_dst_pitch = c_dst_pitch * channel;
+#endif
         __global FLOAT* dst_ptr = output + pos.z * b_dst_pitch + c * c_dst_pitch + h * y_dst_pitch + w * x_dst_pitch;
 
         FLOAT4 value = vload4(0, input + pos.z * b_src_pitch + pos.y * c_src_pitch + h * y_src_pitch + w * x_src_pitch);
@@ -121,11 +53,17 @@ __kernel void pack_buf(__private int global_dim0, __private int global_dim1, __p
         const int y_dst_pitch = x_dst_pitch * width;
         const int c_dst_pitch = y_dst_pitch * height;
         const int b_dst_pitch = c_dst_pitch * ((channel + 3) / 4);
-
+#ifdef MNN_NHWC
+        const int c_src_pitch = 1;
+        const int x_src_pitch = c_src_pitch * channel;
+        const int y_src_pitch = x_src_pitch * width;
+        const int b_src_pitch = y_src_pitch * height;
+#else
         const int x_src_pitch = 1;
         const int y_src_pitch = x_src_pitch * width;
         const int c_src_pitch = y_src_pitch * height;
         const int b_src_pitch = c_src_pitch * channel;
+#endif
         __global FLOAT* src_ptr = input + pos.z * b_src_pitch + c * c_src_pitch + h * y_src_pitch + w * x_src_pitch;
         FLOAT4 value = (FLOAT4)0;
         FLOAT *value_ptr = (FLOAT*)&value;
@@ -136,29 +74,52 @@ __kernel void pack_buf(__private int global_dim0, __private int global_dim1, __p
     }
 }
 
-__kernel void batch_gather_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
-                         __global FLOAT* output, __global FLOAT* input,
-                         __global FLOAT* offset_dst, __global FLOAT* offset_src,
-                         __private const int x_size,
-                         __private const int4 stride_src,
-                         __private const int4 stride_dst,
-                         __private const int2 steps,
-                         __private const int2 iters) {
+#ifdef LOOP_BINARY_OPERATOR
+__kernel void broadcast_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2,
+                         __global FLOAT* output, __global FLOAT* input0, __global FLOAT* input1,
+                         __private const int4 src0_size, //(width, height, channel, batch)
+                         __private const int4 src1_size,
+                         __private const int dst_width, __private const int dst_height,
+                         __private const int channel_block) {
     int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));
     
     if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) {
         
-        int x = pos.x % x_size;
-        int y = pos.x / x_size;
-
-        int2 index = (int2)(pos.z, pos.z);
-        if (iters.x >= 0) {
-            index.x = (int)(offset_dst[pos.z]);
-        }
-        if (iters.y >= 0) {
-            index.y = (int)(offset_src[pos.z]);
+        const int w = pos.x;
+        const int h = pos.y;
+        const int c = pos.z % channel_block;
+        const int n = pos.z / channel_block;
+        const int src0_channel_block = (src0_size.z + 3) / 4;
+        const int src1_channel_block = (src1_size.z + 3) / 4;
+        
+        FLOAT4 in0 = vload4(0, input0 + ((((n * src0_channel_block) + c) * src0_size.y + h) * src0_size.x + w) * 4);
+#ifdef BROADCAST_CHANNEL
+        const int w1 = w % src1_size.x;
+        const int h1 = h % src1_size.y;
+        const int n1 = n % src1_size.w;
+        const int c1 = c << 2;
+        int4 c1_vec = (int4)(c1, c1 + 1, c1 + 2, c1 + 3);
+        c1_vec = c1_vec % (int4)(src1_size.z);
+        int4 c4_vec = (c1_vec + 3) / 4;
+        FLOAT4 in1;
+        FLOAT* in1_ptr = (FLOAT*)&in1;
+        int* c1_vec_prt = (int*)&c1_vec;
+        int* c4_vec_prt = (int*)&c4_vec;
+        for(int i = 0; i < 4; ++i){
+            int remain = (c4_vec_prt[i] << 2) - c1_vec_prt[i];
+            FLOAT4 tmp = vload4(0, input1 + ((((n1 * src1_channel_block) + c4_vec_prt[i]) * src1_size.y + h1) * src1_size.x + w1) * 4);
+            FLOAT* tmp_ptr = (FLOAT*)&tmp;
+            in1_ptr[i] = tmp_ptr[remain];
         }
-        int2 offset = index * steps;
-        output[offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z] = input[offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z];
+#else
+        const int w1 = w % src1_size.x;
+        const int h1 = h % src1_size.y;
+        const int c1 = c;
+        const int n1 = n % src1_size.w;
+        FLOAT4 in1 = vload4(0, input1 + ((((n1 * src1_channel_block) + c1) * src1_size.y + h1) * src1_size.x + w1) * 4);
+#endif
+        FLOAT4 out = CONVERT_FLOAT4(LOOP_BINARY_OPERATOR);
+        vstore4(out, 0, output + ((((n * channel_block) + c) * dst_height + h) * dst_width + w) * 4);
     }
 }
+#endif
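
Note: broadcast_binary_buf reads its operands from NC4HW4-packed buffers, so each vload4 starts at the flattened offset of a 4-channel block; under BROADCAST_CHANNEL it additionally wraps each lane's channel modulo src1's channel count and gathers the four lanes one at a time. A small sketch of the offset computation assumed by those vload4 calls:

    #include <cstdio>

    // Element offset of channel-block c at (n, h, w) in an NC4HW4 buffer,
    // matching input0 + ((((n * cBlocks) + c) * H + h) * W + w) * 4 in the kernel.
    static int nc4hw4Offset(int n, int c, int h, int w, int cBlocks, int H, int W) {
        return ((((n * cBlocks) + c) * H + h) * W + w) * 4;
    }

    int main() {
        // batch 0, channel block 2, pixel (h=3, w=5) of an 8x8 feature map with 2 blocks:
        std::printf("%d\n", nc4hw4Offset(0, 2, 3, 5, 2, 8, 8));  // 628
    }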

+ 8 - 16
source/backend/opencl/execution/cl/matmul_buf.cl

@@ -171,10 +171,11 @@ __kernel void matmul_transA_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a
     for (short pos = 0; pos < channel_blocks; pos += 1) {
 
         const int inpa_offset = (4*pos) * height_blocks + height_blocks_idx;
+        short remain = (pos + 1) * 4 - channels;
         FLOAT4 a0 = vload4(inpa_offset, input_a);
-        FLOAT4 a1 = vload4(inpa_offset + height_blocks, input_a);
-        FLOAT4 a2 = vload4(inpa_offset + height_blocks*2, input_a);
-        FLOAT4 a3 = vload4(inpa_offset + height_blocks*3, input_a);
+        FLOAT4 a1 = ((remain >= 3) ? v_zero : vload4(inpa_offset + height_blocks, input_a));
+        FLOAT4 a2 = ((remain >= 2) ? v_zero : vload4(inpa_offset + height_blocks*2, input_a));
+        FLOAT4 a3 = ((remain >= 1) ? v_zero : vload4(inpa_offset + height_blocks*3, input_a));
 
         const int inpb_offset = (4*pos) * width_blocks + width_blocks_idx;
         FLOAT4 b0 = vload4(inpb_offset, input_b);
@@ -182,11 +183,6 @@ __kernel void matmul_transA_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a
         FLOAT4 b2 = vload4(inpb_offset + width_blocks*2, input_b);
         FLOAT4 b3 = vload4(inpb_offset + width_blocks*3, input_b);
 
-        short remain = (pos + 1) * 4 - channels;
-        a3 = ((remain >= 1) ? v_zero : a3);
-        a2 = ((remain >= 2) ? v_zero : a2);
-        a1 = ((remain >= 3) ? v_zero : a1);
-
         FLOAT4 a0_trans = (FLOAT4)(a0.x, a1.x, a2.x, a3.x);
         FLOAT4 a1_trans = (FLOAT4)(a0.y, a1.y, a2.y, a3.y);
         FLOAT4 a2_trans = (FLOAT4)(a0.z, a1.z, a2.z, a3.z);
@@ -261,10 +257,11 @@ __kernel void matmul_transA_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT*
 
     for (short pos = 0; pos < channel_blocks; pos += 1) {
         const int inpa_offset = (4*pos) * height_blocks + height_blocks_idx;
+        short remain = (pos + 1) * 4 - channels;
         FLOAT4 a0 = vload4(inpa_offset, input_a);
-        FLOAT4 a1 = vload4(inpa_offset + height_blocks, input_a);
-        FLOAT4 a2 = vload4(inpa_offset + height_blocks*2, input_a);
-        FLOAT4 a3 = vload4(inpa_offset + height_blocks*3, input_a);
+        FLOAT4 a1 = ((remain >= 3) ? v_zero : vload4(inpa_offset + height_blocks, input_a));
+        FLOAT4 a2 = ((remain >= 2) ? v_zero : vload4(inpa_offset + height_blocks*2, input_a));
+        FLOAT4 a3 = ((remain >= 1) ? v_zero : vload4(inpa_offset + height_blocks*3, input_a));
 
         const int inpb_offset = (4*width_blocks_idx) * channel_blocks + pos;
         FLOAT4 b0 = vload4(inpb_offset, input_b);
@@ -272,11 +269,6 @@ __kernel void matmul_transA_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT*
         FLOAT4 b2 = vload4(inpb_offset + channel_blocks*2, input_b);
         FLOAT4 b3 = vload4(inpb_offset + channel_blocks*3, input_b);
 
-        short remain = (pos + 1) * 4 - channels;
-        a3 = ((remain >= 1) ? v_zero : a3);
-        a2 = ((remain >= 2) ? v_zero : a2);
-        a1 = ((remain >= 3) ? v_zero : a1);
-
         FLOAT4 a0_trans = (FLOAT4)(a0.x, a1.x, a2.x, a3.x);
         FLOAT4 a1_trans = (FLOAT4)(a0.y, a1.y, a2.y, a3.y);
         FLOAT4 a2_trans = (FLOAT4)(a0.z, a1.z, a2.z, a3.z);

Diff data for this file are not shown because the file is too large
+ 7 - 7
source/backend/opencl/execution/cl/opencl_program.cc


+ 249 - 105
source/backend/opencl/execution/cl/softmax.cl

@@ -14,124 +14,268 @@
 __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 
 
-__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __write_only image2d_t output, __private const int output_channels,
+__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __write_only image2d_t output,
                               __private const int remain_channels, __private const int4 shape // NCHW
                               ) {
 
-    const int width_idx    = get_global_id(0);
-    const int batch_height_idx       = get_global_id(1);
+    const int x = get_global_id(0);
+    const int w = get_global_id(1);
+    const int bh = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, w, bh);
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
 
+    FLOAT4 maxValue = (FLOAT4)-FLT_MAX;
+    for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)));
+    }
+
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+
+    maxValue.x = fmax(maxValue.x, maxValue.y);
+    maxValue.x = fmax(maxValue.x, maxValue.z);
+    maxValue.x = fmax(maxValue.x, maxValue.w);
+
+    FLOAT4 input_data = RI_F(input, SAMPLER, (int2)(w + (shape.y - 1) * shape.w , bh));
+    if (remain_channels == 0) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.w);
+    } else if (remain_channels == 1) {
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 2) {
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 3) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    }
+
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)) - (FLOAT4)maxValue.x);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+    sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w;
+    
+    
+    input_data -= maxValue.x;
+    if (remain_channels == 0) {
+        sumValue.x += exp(input_data.w);
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 1) {
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 2) {
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 3) {
+        sumValue.x += exp(input_data.x);
+    }
+    for(int i = lid; i < shape.y; i+=SOFTMAX_LOCAL_SIZE){
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)) - maxValue.x) / sumValue.x;
+        WI_F(output, (int2)(w + i * shape.w, bh), value);
+    }
+#else
+    FLOAT4 maxValue = (FLOAT4)-FLT_MAX;
+    for (int i = 0; i < shape.y - 1; i++) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)));
+    }
     
-    if (width_idx < shape.w && batch_height_idx < shape.x*shape.z) {
-
-        FLOAT4 float_max_value = (FLOAT4)-FLT_MAX;
-        FLOAT4 input_data;
-        for (short i = 0; i < shape.y - 1; ++i) {
-            input_data      = RI_F(input, SAMPLER, (int2)(width_idx + i * shape.w, batch_height_idx));
-            float_max_value = max(float_max_value, input_data);
-        }
-        float_max_value.x = max(float_max_value.x, float_max_value.y);
-        float_max_value.x = max(float_max_value.x, float_max_value.z);
-        float_max_value.x = max(float_max_value.x, float_max_value.w);
-
-        input_data = RI_F(input, SAMPLER, (int2)(width_idx + (shape.y - 1) * shape.w , batch_height_idx));
-        if (remain_channels == 0) {
-            float_max_value.x = max(float_max_value.x, input_data.x);
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.z);
-            float_max_value.x = max(float_max_value.x, input_data.w);
-        } else if (remain_channels == 1) {
-            float_max_value.x = max(float_max_value.x, input_data.z);
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        } else if (remain_channels == 2) {
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        } else if (remain_channels == 3) {
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        }
-
-
-        FLOAT4 accum_result       = 0;
-        for (short i = 0; i < shape.y - 1; ++i) {
-            input_data = RI_F(input, SAMPLER, (int2)(width_idx + i * shape.w, batch_height_idx));
-            input_data = EXP(input_data - float_max_value.x);
-            accum_result += input_data;
-        }
-        accum_result.x = accum_result.x + accum_result.y + accum_result.z + accum_result.w;
-
-        input_data = RI_F(input, SAMPLER, (int2)(width_idx + (shape.y - 1) * shape.w, batch_height_idx));
-        input_data -= float_max_value.x;
-        if (remain_channels == 0) {
-            accum_result.x += EXP(input_data.w);
-            accum_result.x += EXP(input_data.z);
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 1) {
-            accum_result.x += EXP(input_data.z);
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 2) {
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 3) {
-            accum_result.x += EXP(input_data.x);
-        }
-        
-        for(int i = 0; i < shape.y; ++i){
-            int cur_out_width_pos  = mad24(i, shape.w, width_idx);
-            input_data = RI_F(input, SAMPLER, (int2)(cur_out_width_pos, batch_height_idx)) - float_max_value.x;
-            input_data = EXP(input_data) / accum_result.x;
-            WI_F(output, (int2)(cur_out_width_pos, batch_height_idx), input_data);
-        }
+    maxValue.x = fmax(maxValue.x, maxValue.y);
+    maxValue.x = fmax(maxValue.x, maxValue.z);
+    maxValue.x = fmax(maxValue.x, maxValue.w);
+
+    FLOAT4 input_data = RI_F(input, SAMPLER, (int2)(w + (shape.y - 1) * shape.w , bh));
+    if (remain_channels == 0) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.w);
+    } else if (remain_channels == 1) {
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 2) {
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 3) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    }
+
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i = 0; i < shape.y - 1; i++) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)) - (FLOAT4)maxValue.x);
     }
+    sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w;
+    input_data -= maxValue.x;
+    if (remain_channels == 0) {
+        sumValue.x += exp(input_data.w);
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 1) {
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 2) {
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 3) {
+        sumValue.x += exp(input_data.x);
+    }
+    for(int i = 0; i < shape.y; i++){
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(w + i * shape.w, bh)) - maxValue.x) / sumValue.x;
+        WI_F(output, (int2)(w + i * shape.w, bh), value);
+    }
+#endif
 }
 
-__kernel void softmax_height(__read_only image2d_t input, __write_only image2d_t output,
-                      __private const int4 shape // NCHW
+__kernel void softmax_height(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __write_only image2d_t output,
+                      __private const int remain_channels, __private const int4 shape // NCHW
                       ) {
-    int wc = get_global_id(0);
-    int b = get_global_id(1);
-    if (wc < shape.y*shape.w && b < shape.x) {
-        /*Compute Max */
-        FLOAT4 maxValue = RI_F(input, SAMPLER, (int2)(wc, b*shape.z));
-        for (int i=1; i<shape.z; ++i) {
-            maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)));
-        }
-        /*Compute Exp Sum*/
-        FLOAT4 sumValue = (FLOAT4)0;
-        for (int i=0; i<shape.z; ++i) {
-            sumValue += exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue);
-        }
-        /*Compute Result */
-        for (int i=0; i<shape.z; ++i) {
-            FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue) / sumValue;
-            WI_F(output, (int2)(wc, b*shape.z+i), value);
-        }
+    const int x = get_global_id(0);
+    const int wc = get_global_id(1);
+    const int b = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, wc, b);
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)));
+    }
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+    
+    /*Compute Result */
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue) / sumValue;
+        WI_F(output, (int2)(wc, b*shape.z+i), value);
+    }
+#else
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=0; i<shape.z; i++) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)));
     }
+    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=0; i<shape.z; i++) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue);
+    }
+    
+    /*Compute Result */
+    for (int i=0; i<shape.z; i++) {
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(wc, b*shape.z+i)) - maxValue) / sumValue;
+        WI_F(output, (int2)(wc, b*shape.z+i), value);
+    }
+#endif
 }
 
 
-__kernel void softmax_width(__read_only image2d_t input, __write_only image2d_t output,
-                      __private const int4 shape // NCHW
+__kernel void softmax_width(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, __write_only image2d_t output,
+                      __private const int remain_channels, __private const int4 shape // NCHW
                       ) {
-    int c = get_global_id(0);
-    int bh = get_global_id(1);
-    if (c < shape.y && bh < shape.x*shape.z) {
-        /*Compute Max */
-        FLOAT4 maxValue = RI_F(input, SAMPLER, (int2)(c*shape.w, bh));
-        for (int i=1; i<shape.w; ++i) {
-            maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)));
-        }
-        /*Compute Exp Sum*/
-        FLOAT4 sumValue = (FLOAT4)0;
-        for (int i=0; i<shape.w; ++i) {
-            sumValue += exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue);
-        }
-        /*Compute Result */
-        for (int i=0; i<shape.w; ++i) {
-            FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue) / sumValue;
-            WI_F(output, (int2)(c*shape.w+i, bh), value);
-        }
+    const int x = get_global_id(0);
+    const int c = get_global_id(1);
+    const int bh = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, c, bh);
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
+    
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)));
     }
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+    
+    /*Compute Result */
+    for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue) / sumValue;
+        WI_F(output, (int2)(c*shape.w+i, bh), value);
+    }
+#else
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=0; i<shape.w; i++) {
+        maxValue = fmax(maxValue, RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)));
+    }
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=0; i<shape.z; i++) {
+        sumValue += exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue);
+    }
+    
+    /*Compute Result */
+    for (int i=0; i<shape.w; i++) {
+        FLOAT4 value = exp(RI_F(input, SAMPLER, (int2)(c*shape.w+i, bh)) - maxValue) / sumValue;
+        WI_F(output, (int2)(c*shape.w+i, bh), value);
+    }
+#endif
 }
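Note: the rewritten softmax_channel kernel keeps the usual numerically stable structure (channel-wise max, exp-sum relative to that max, then normalisation), with the last channel block handled through remain_channels; the SOFTMAX_LOCAL_SIZE branch only changes how the max and the sum are reduced across work-items. The scalar C++ reference below shows that structure for comparison; the vectorised image reads and the remainder unrolling are deliberately not modelled.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over "channels": subtract the max before exp,
// the same max / exp-sum / normalise sequence as the kernel above.
static void softmax(std::vector<float>& v) {
    const float m = *std::max_element(v.begin(), v.end());
    float sum = 0.0f;
    for (float x : v) sum += std::exp(x - m);
    for (float& x : v) x = std::exp(x - m) / sum;
}

int main() {
    // 5 channels -> channelBlocks = 2, remain_channels = 3 in the kernel's terms.
    std::vector<float> c = {1.f, 2.f, 3.f, 4.f, 5.f};
    softmax(c);
    float total = 0.0f;
    for (float x : c) { printf("%.4f ", x); total += x; }
    printf("\nsum = %.4f\n", total);   // ~1.0
    return 0;
}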

+ 256 - 108
source/backend/opencl/execution/cl/softmax_buf.cl

@@ -15,138 +15,286 @@
 __kernel void softmax_channel(GLOBAL_SIZE_3_DIMS
                               __global const FLOAT *input,
                               __global FLOAT *output,
-                              __private const int output_channels,
                               __private const int remain_channels,
                               __private const int4 shape) {//NCHW
 
-    const int width_idx    = get_global_id(0);
-    const int batch_height_idx       = get_global_id(1);
-
-    if (width_idx < shape.w && batch_height_idx < shape.x*shape.z) {
-        const int batch_idx = batch_height_idx / shape.z;
-        const int height_idx = batch_height_idx % shape.z;
-        const int offset = (((batch_idx*shape.y+0)*shape.z+height_idx)*shape.w+width_idx)*4;
-
-        FLOAT4 float_max_value = (FLOAT4)-FLT_MAX;
-        FLOAT4 input_data;
-        for (short i = 0; i < shape.y - 1; ++i) {
-            input_data      = vload4(i*shape.z*shape.w, input+offset);
-            float_max_value = max(float_max_value, input_data);
-        }
-        
-        float_max_value.x = max(float_max_value.x, float_max_value.y);
-        float_max_value.x = max(float_max_value.x, float_max_value.z);
-        float_max_value.x = max(float_max_value.x, float_max_value.w);
-
-        input_data = vload4((shape.y - 1)*shape.z*shape.w, input+offset);
-        if (remain_channels == 0) {
-            float_max_value.x = max(float_max_value.x, input_data.x);
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.z);
-            float_max_value.x = max(float_max_value.x, input_data.w);
-        } else if (remain_channels == 1) {
-            float_max_value.x = max(float_max_value.x, input_data.z);
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        } else if (remain_channels == 2) {
-            float_max_value.x = max(float_max_value.x, input_data.y);
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        } else if (remain_channels == 3) {
-            float_max_value.x = max(float_max_value.x, input_data.x);
-        }
-
-        FLOAT4 accum_result       = 0;
-        for (short i = 0; i < shape.y - 1; ++i) {
-            input_data = vload4(i*shape.z*shape.w, input+offset);;
-            input_data = EXP(input_data - float_max_value.x);
-            accum_result += input_data;
-        }
-        accum_result.x = accum_result.x + accum_result.y + accum_result.z + accum_result.w;
-
-        input_data = vload4((shape.y - 1)*shape.z*shape.w, input+offset);
-        input_data -= float_max_value.x;
-        if (remain_channels == 0) {
-            accum_result.x += EXP(input_data.w);
-            accum_result.x += EXP(input_data.z);
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 1) {
-            accum_result.x += EXP(input_data.z);
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 2) {
-            accum_result.x += EXP(input_data.y);
-            accum_result.x += EXP(input_data.x);
-        } else if (remain_channels == 3) {
-            accum_result.x += EXP(input_data.x);
-        }
-
-        for(int i = 0; i < shape.y; ++i){
-            input_data = vload4(i*shape.z*shape.w, input+offset) - float_max_value.x;
-            input_data = EXP(input_data) / accum_result.x;
-            vstore4(input_data, i*shape.z*shape.w, output+offset);
-        }
+    const int x = get_global_id(0);
+    const int w = get_global_id(1);
+    const int bh = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, w, bh);
+    
+    const int batch_idx = bh / shape.z;
+    const int height_idx = bh % shape.z;
+    const int offset = (((batch_idx*shape.y+0)*shape.z+height_idx)*shape.w+w)*4;
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
+
+    FLOAT4 maxValue = (FLOAT4)-FLT_MAX;
+    for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, vload4(i*shape.z*shape.w, input+offset));
+    }
+
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+
+    maxValue.x = fmax(maxValue.x, maxValue.y);
+    maxValue.x = fmax(maxValue.x, maxValue.z);
+    maxValue.x = fmax(maxValue.x, maxValue.w);
+
+    FLOAT4 input_data = vload4((shape.y - 1) *shape.z*shape.w, input+offset);
+    if (remain_channels == 0) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.w);
+    } else if (remain_channels == 1) {
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 2) {
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 3) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    }
+
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(vload4(i*shape.z*shape.w, input+offset) - (FLOAT4)maxValue.x);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+    sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w;
+    
+    
+    input_data -= maxValue.x;
+    if (remain_channels == 0) {
+        sumValue.x += exp(input_data.w);
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 1) {
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 2) {
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 3) {
+        sumValue.x += exp(input_data.x);
+    }
+    for(int i = lid; i < shape.y; i+=SOFTMAX_LOCAL_SIZE){
+        FLOAT4 value = exp(vload4(i*shape.z*shape.w, input+offset) - maxValue.x) / sumValue.x;
+        vstore4(value, i*shape.z*shape.w, output+offset);
+    }
+#else
+    FLOAT4 maxValue = (FLOAT4)-FLT_MAX;
+    for (int i = 0; i < shape.y - 1; i++) {
+        maxValue = fmax(maxValue, vload4(i*shape.z*shape.w, input+offset));
+    }
+    
+    maxValue.x = fmax(maxValue.x, maxValue.y);
+    maxValue.x = fmax(maxValue.x, maxValue.z);
+    maxValue.x = fmax(maxValue.x, maxValue.w);
+
+    FLOAT4 input_data = vload4((shape.y - 1) *shape.z*shape.w, input+offset);
+    if (remain_channels == 0) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.w);
+    } else if (remain_channels == 1) {
+        maxValue.x = fmax(maxValue.x, input_data.z);
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 2) {
+        maxValue.x = fmax(maxValue.x, input_data.y);
+        maxValue.x = fmax(maxValue.x, input_data.x);
+    } else if (remain_channels == 3) {
+        maxValue.x = fmax(maxValue.x, input_data.x);
     }
+
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i = 0; i < shape.y - 1; i++) {
+        sumValue += exp(vload4(i*shape.z*shape.w, input+offset) - (FLOAT4)maxValue.x);
+    }
+    sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w;
+    input_data -= maxValue.x;
+    if (remain_channels == 0) {
+        sumValue.x += exp(input_data.w);
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 1) {
+        sumValue.x += exp(input_data.z);
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 2) {
+        sumValue.x += exp(input_data.y);
+        sumValue.x += exp(input_data.x);
+    } else if (remain_channels == 3) {
+        sumValue.x += exp(input_data.x);
+    }
+    for(int i = 0; i < shape.y; i++){
+        FLOAT4 value = exp(vload4(i*shape.z*shape.w, input+offset) - maxValue.x) / sumValue.x;
+        vstore4(value, i*shape.z*shape.w, output+offset);
+    }
+#endif
 }
 
 
-__kernel void softmax_height(__global const FLOAT *input,
+__kernel void softmax_height(GLOBAL_SIZE_3_DIMS
+                             __global const FLOAT *input,
                              __global FLOAT *output,
+                             __private const int remain_channels,
                              __private const int4 shape // NCHW
                              ) {
-    int wc = get_global_id(0);
-    int b = get_global_id(1);
+    const int x = get_global_id(0);
+    const int wc = get_global_id(1);
+    const int b = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, wc, b);
     
     const int c = wc / shape.w;
     const int w = wc % shape.w;
     const int offset = (((b*shape.y+c)*shape.z+0)*shape.w+w)*4;
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
+    
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, vload4(i*shape.w, input+offset));
+    }
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(vload4(i*shape.w, input+offset) - maxValue);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+
+    /*Compute Result */
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        FLOAT4 value = exp(vload4(i*shape.w, input+offset) - maxValue) / sumValue;
+        vstore4(value, i*shape.w, output+offset);
+    }
+#else
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=0; i<shape.z; i++) {
+        maxValue = fmax(maxValue, vload4(i*shape.w, input+offset));
+    }
     
-    if (wc < shape.y*shape.w && b < shape.x) {
-        /*Compute Max */
-        FLOAT4 maxValue = vload4(0, input+offset);
-        for (int i=1; i<shape.z; ++i) {
-            maxValue = fmax(maxValue, vload4(i*shape.w, input+offset));
-        }
-        /*Compute Exp Sum*/
-        FLOAT4 sumValue = (FLOAT4)0;
-        for (int i=0; i<shape.z; ++i) {
-            sumValue += exp(vload4(i*shape.w, input+offset) - maxValue);
-        }
-        /*Compute Result */
-        for (int i=0; i<shape.z; ++i) {
-            FLOAT4 value = exp(vload4(i*shape.w, input+offset) - maxValue) / sumValue;
-            vstore4(value, i*shape.w, output+offset);
-        }
-    }    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=0; i<shape.z; i++) {
+        sumValue += exp(vload4(i*shape.w, input+offset) - maxValue);
+    }
+
+    /*Compute Result */
+    for (int i=0; i<shape.z; i++) {
+        FLOAT4 value = exp(vload4(i*shape.w, input+offset) - maxValue) / sumValue;
+        vstore4(value, i*shape.w, output+offset);
+    }
+#endif
 }
 
 
-__kernel void softmax_width(__global const FLOAT *input,
+__kernel void softmax_width(GLOBAL_SIZE_3_DIMS
+                            __global const FLOAT *input,
                             __global FLOAT *output,
+                            __private const int remain_channels,
                             __private const int4 shape // NCHW
                             ) {
-    int c = get_global_id(0);
-    int bh = get_global_id(1);
-    
+    const int x = get_global_id(0);
+    const int c = get_global_id(1);
+    const int bh = get_global_id(2);
+    DEAL_NON_UNIFORM_DIM3(x, c, bh);
     const int b = bh / shape.z;
     const int h = bh % shape.z;
     const int offset = (((b*shape.y+c)*shape.z+h)*shape.w+0)*4;
+#if SOFTMAX_LOCAL_SIZE >= 4
+    int lid = get_local_id(0);
+    FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];
+    
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {
+        maxValue = fmax(maxValue, vload4(i, input+offset));
+    }
+    sum[lid] = maxValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = fmax(sum[lid], sum[lid + i]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    maxValue = sum[0];
+    
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=lid; i<shape.z; i+=SOFTMAX_LOCAL_SIZE) {
+        sumValue += exp(vload4(i, input+offset) - maxValue);
+    }
+    sum[lid] = sumValue;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){
+        if (lid < i)
+            sum[lid] = sum[lid] + sum[lid + i];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    sumValue = sum[0];
+    
+    /*Compute Result */
+    for (int i=lid; i<shape.w; i+=SOFTMAX_LOCAL_SIZE) {
+        FLOAT4 value = exp(vload4(i, input+offset) - maxValue) / sumValue;
+        vstore4(value, i, output+offset);
+    }
+#else
+    /*Compute Max */
+    FLOAT4 maxValue = (FLOAT4)(-FLT_MAX);
+    for (int i=0; i<shape.w; i++) {
+        maxValue = fmax(maxValue, vload4(i, input+offset));
+    }
+    /*Compute Exp Sum*/
+    FLOAT4 sumValue = (FLOAT4)0;
+    for (int i=0; i<shape.z; i++) {
+        sumValue += exp(vload4(i, input+offset) - maxValue);
+    }
     
-    if (c < shape.y && bh < shape.x*shape.z) {
-        /*Compute Max */
-        FLOAT4 maxValue = vload4(0, input+offset);
-        for (int i=1; i<shape.w; ++i) {
-            maxValue = fmax(maxValue, vload4(i, input+offset));
-        }
-        /*Compute Exp Sum*/
-        FLOAT4 sumValue = (FLOAT4)0;
-        for (int i=0; i<shape.w; ++i) {
-            sumValue += exp(vload4(i, input+offset) - maxValue);
-        }
-        /*Compute Result */
-        for (int i=0; i<shape.w; ++i) {
-            FLOAT4 value = exp(vload4(i, input+offset) - maxValue) / sumValue;
-            vstore4(value, i, output+offset);
-        }
+    /*Compute Result */
+    for (int i=0; i<shape.w; i++) {
+        FLOAT4 value = exp(vload4(i, input+offset) - maxValue) / sumValue;
+        vstore4(value, i, output+offset);
     }
+#endif
 }
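Note: softmax.cl and softmax_buf.cl share the same local-memory pattern when SOFTMAX_LOCAL_SIZE >= 4: each work-item accumulates a strided partial result, halving passes separated by barriers combine the partials, and every work-item then reads sum[0]. The C++ sketch below models that tree reduction sequentially so the indexing can be checked; LOCAL_SIZE and the data are arbitrary choices for the example.

#include <cmath>
#include <cstdio>
#include <vector>

// Model of the SOFTMAX_LOCAL_SIZE reduction: LOCAL_SIZE "work-items" each
// accumulate a strided partial sum, then a halving tree combines them.
int main() {
    const int LOCAL_SIZE = 8;            // must be a power of two, as in the kernels
    const int n = 37;                    // hypothetical reduction length
    std::vector<double> data(n);
    for (int i = 0; i < n; ++i) data[i] = std::sin((double)i);   // arbitrary values

    std::vector<double> sum(LOCAL_SIZE, 0.0);
    for (int lid = 0; lid < LOCAL_SIZE; ++lid)                    // strided partial sums
        for (int i = lid; i < n; i += LOCAL_SIZE) sum[lid] += data[i];

    for (int stride = LOCAL_SIZE / 2; stride > 0; stride /= 2)    // tree combine
        for (int lid = 0; lid < stride; ++lid) sum[lid] += sum[lid + stride];

    double ref = 0.0;
    for (double x : data) ref += x;
    printf("tree = %.6f  reference = %.6f\n", sum[0], ref);
    return 0;
}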

+ 1 - 1
source/backend/opencl/execution/image/ConvExecution.cpp

@@ -85,7 +85,7 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
 
     std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
     if (nullptr != conv2dParams->quanParameter()) {
-        quanCommon = ConvolutionCommon::load(conv2dParams->quanParameter(), true);
+        quanCommon = ConvolutionCommon::load(conv2dParams, backend, true);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution: %s \n", op->name()->c_str());
         }

+ 1 - 1
source/backend/opencl/execution/image/ConvWinograd.cpp

@@ -69,7 +69,7 @@ ConvWinograd::ConvWinograd(const MNN::Convolution2D* op, Backend* backend) : Exe
 
     std::shared_ptr<MNN::ConvolutionCommon::Int8Common> quanCommon;
     if (nullptr != op->quanParameter()) {
-        quanCommon = ConvolutionCommon::load(op->quanParameter(), true);
+        quanCommon = ConvolutionCommon::load(op, backend, true);
         if (nullptr == quanCommon) {
             MNN_ERROR("Memory not Enough, can't extract IDST Convolution \n");
         }

+ 1 - 1
source/backend/opencl/execution/image/DeconvExecution.cpp

@@ -33,7 +33,7 @@ DeconvExecution::DeconvExecution(const std::vector<Tensor *> &inputs, const MNN:
     const float* filterDataPtr = nullptr;
     int weightSize = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2dParams, &filterDataPtr, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2dParams, &filterDataPtr, &weightSize);
 
     int inputChannel  = weightSize / (kernelWidth * kernelHeight * outputChannel);
     std::vector<int> filterShape{outputChannel, inputChannel, kernelHeight, kernelWidth};

+ 1 - 1
source/backend/opencl/execution/image/DepthwiseConvExecution.cpp

@@ -37,7 +37,7 @@ DepthwiseConvExecution::DepthwiseConvExecution(const std::vector<Tensor *> &inpu
     const float* filterDataPtr = nullptr;
     int filterDataSize   = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, mCon2dParams, &filterDataPtr, &filterDataSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, mCon2dParams, &filterDataPtr, &filterDataSize);
 
     mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]}));
     std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>(filterShape));

+ 1 - 1
source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp

@@ -36,7 +36,7 @@ DepthwiseDeconvExecution::DepthwiseDeconvExecution(const std::vector<Tensor *> &
     const float* filterDataPtr = nullptr;
     int tempWeightSize   = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, mCon2dParams, &filterDataPtr, &tempWeightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend, mCon2dParams, &filterDataPtr, &tempWeightSize);
 
     mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]}));
     std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>(filterShape));

+ 1 - 1
source/backend/opencl/execution/image/EltwiseExecution.cpp

@@ -207,7 +207,7 @@ public:
                 case BinaryOpOperation_SquaredDifference:
                     return new EltwiseExecution(inputs, "(in0-in1)*(in0-in1)", op, backend);
                 case BinaryOpOperation_ATAN2:
-                    return new EltwiseExecution(inputs, "atan(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    return new EltwiseExecution(inputs, "(in1==(FLOAT4)0?(sign(in0)*(FLOAT4)(PI/2)):(atan(in0/in1)+(in1>(FLOAT4)0?(FLOAT4)0:sign(in0)*(FLOAT4)PI)))", op, backend);
                 case BinaryOpOperation_NOTEQUAL:
                     return new EltwiseExecution(inputs, "convert_float4(-isnotequal(in0,in1))", op, backend);
                 case BinaryOpOperation_MOD:
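Note: the new ATAN2 expression replaces the earlier guarded-division form with explicit quadrant handling: in1 == 0 maps to sign(in0)*PI/2, and a sign(in0)*PI correction is added when in1 < 0. Below is a plain C++ check of that branch logic against std::atan2; it uses scalar doubles only and does not model the FLOAT4 select semantics of the kernel string.

#include <cmath>
#include <cstdio>

// atan2 built from atan: x == 0 -> sign(y)*pi/2, otherwise atan(y/x) plus a
// sign(y)*pi correction when x < 0, mirroring the expression above (y=in0, x=in1).
static double atan2_expr(double y, double x) {
    const double PI = 3.14159265358979323846;
    auto sign = [](double v) { return (v > 0) - (v < 0); };
    if (x == 0.0) return sign(y) * (PI / 2);
    return std::atan(y / x) + (x > 0 ? 0.0 : sign(y) * PI);
}

int main() {
    const double ys[] = {1.0, -2.0, 0.5, -0.5, 3.0};
    const double xs[] = {2.0,  1.5, -1.0, -2.0, 0.0};
    for (int i = 0; i < 5; ++i)
        printf("y=%4.1f x=%4.1f  expr=% .6f  std::atan2=% .6f\n",
               ys[i], xs[i], atan2_expr(ys[i], xs[i]), std::atan2(ys[i], xs[i]));
    return 0;
}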

+ 121 - 2
source/backend/opencl/execution/image/LoopExecution.cpp

@@ -15,7 +15,11 @@ namespace OpenCL {
 
 static void _TileTensor(Tensor *input, cl::Buffer *output, cl::Kernel& kernel, cl::NDRange &globalWorkSize,
                         cl::NDRange &localWorkSize, const int Width, const int Height, const int Channel,
-                        const int Batch, OpenCLRuntime *runTime, const std::set<std::string> &buildOptions) {
+                        const int Batch, OpenCLRuntime *runTime, std::set<std::string> buildOptions) {
+    
+    if (TensorUtils::getDescribe(input)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){
+        buildOptions.emplace("-DMNN_NHWC");
+    }
     kernel = runTime->buildKernel("loop", "tile", buildOptions);
     uint32_t mMaxWorkGroupSize  = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(kernel));
     std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(Width * Height), (uint32_t)(UP_DIV(Channel, 4)), (uint32_t)(Batch)};
@@ -42,7 +46,10 @@ static void _TileTensor(Tensor *input, cl::Buffer *output, cl::Kernel& kernel, c
 
 static void _PackTensor(cl::Buffer *input, Tensor *output, cl::Kernel& kernel, cl::NDRange &globalWorkSize,
                         cl::NDRange &localWorkSize, const int Width, const int Height, const int Channel,
-                        const int Batch, OpenCLRuntime *runTime, const std::set<std::string> &buildOptions) {
+                        const int Batch, OpenCLRuntime *runTime, std::set<std::string> buildOptions) {
+    if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){
+        buildOptions.emplace("-DMNN_NHWC");
+    }
     kernel = runTime->buildKernel("loop", "pack", buildOptions);
     uint32_t mMaxWorkGroupSize  = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(kernel));
     std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(Width * Height), (uint32_t)(UP_DIV(Channel, 4)), (uint32_t)(Batch)};
@@ -353,6 +360,75 @@ ErrorCode LoopBatchMatMulExecution::onResize(const std::vector<Tensor *> &inputs
     return NO_ERROR;
 }
 
+LoopBinaryExecution::LoopBinaryExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn)
+    : CommonExecution(bn, op) {
+    mLoop = loop;
+    mTensors.resize(mLoop->tensorNumber());
+    auto cmd = loop->commands()->GetAs<RegionCommand>(0);
+    mBuildOptions.emplace("-DLOOP_BINARY_OPERATOR=" + compute);
+}
+ErrorCode LoopBinaryExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    auto cmd                      = mLoop->commands()->GetAs<RegionCommand>(0);
+    OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend();
+    auto runTime                  = mOpenCLBackend->getOpenCLRuntime();
+    startRecord(runTime, mRecording);
+    _setTensorStack(mTensors, inputs, outputs, mLoop);
+    mUnits.clear();
+    Unit unit;
+    auto input0 = mTensors[cmd->indexes()->data()[1]];
+    std::vector<int> Input0Shape = tensorShapeFormat(input0);
+    int Input0Size[4] = {Input0Shape.at(2), Input0Shape.at(1),Input0Shape.at(3),Input0Shape.at(0)};
+         
+    auto input1 = mTensors[cmd->indexes()->data()[2]];
+    std::vector<int> Input1Shape = tensorShapeFormat(input1);
+    int Input1Size[4] = {Input1Shape.at(2), Input1Shape.at(1),Input1Shape.at(3),Input1Shape.at(0)};
+         
+    auto output = mTensors[cmd->indexes()->data()[0]];
+    std::vector<int> Shape = tensorShapeFormat(output);
+    const int Channel = Shape.at(3);
+    const int Width = Shape.at(2);
+    const int Height = Shape.at(1);
+    const int Batch = Shape.at(0);
+    const int ChannelBlock = UP_DIV(Channel, 4);
+    auto BuildOptions = mBuildOptions;
+    if(Input0Size[2] != Input1Size[2]){
+        BuildOptions.emplace("-DBROADCAST_CHANNEL");
+    }
+    std::string KernelName = "broadcast_binary";
+    unit.kernel = runTime->buildKernel("loop", KernelName, BuildOptions);
+    uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runTime->getMaxWorkGroupSize(unit.kernel));
+        
+       
+    std::vector<uint32_t> mGlobalWorkSize = {(uint32_t)(Width), (uint32_t)(Height), (uint32_t)(Batch * ChannelBlock)};
+
+    uint32_t index = 0;
+    cl_int ret = CL_SUCCESS;
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[0]);
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[1]);
+    ret |= unit.kernel.setArg(index++, mGlobalWorkSize[2]);
+    ret |= unit.kernel.setArg(index++, openCLImage(output));
+    ret |= unit.kernel.setArg(index++, openCLImage(input0));
+    ret |= unit.kernel.setArg(index++, openCLImage(input1));
+    ret |= unit.kernel.setArg(index++, sizeof(Input0Size), Input0Size);
+    ret |= unit.kernel.setArg(index++, sizeof(Input1Size), Input1Size);
+    ret |= unit.kernel.setArg(index++, Width);
+    ret |= unit.kernel.setArg(index++, Height);
+    ret |= unit.kernel.setArg(index++, ChannelBlock);
+    MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryExecution");
+
+    std::vector<uint32_t> mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel).first;
+
+    unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]};
+    unit.localWorkSize  = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]};
+    recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize, runTime);
+    mUnits.emplace_back(unit);
+
+    endRecord(runTime, mRecording);
+
+    return NO_ERROR;
+}
+
+
 class LoopCreator : public OpenCLBackend::Creator {
 public:
     virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
@@ -374,6 +450,49 @@ public:
             if (OpType_MatMul == subop->type() && loop->parallel()) {
                 return new LoopBatchMatMulExecution(loop, op, backend);
             }
+            if (OpType_BinaryOp == subop->type() && loop->parallel()) {
+                switch (subop->main_as_BinaryOp()->opType()) {
+                    case BinaryOpOperation_MUL:
+                        return new LoopBinaryExecution(loop, "in0*in1", op, backend);
+                    case BinaryOpOperation_ADD:
+                        return new LoopBinaryExecution(loop, "in0+in1", op, backend);
+                    case BinaryOpOperation_SUB:
+                        return new LoopBinaryExecution(loop, "in0-in1", op, backend);
+                    case BinaryOpOperation_REALDIV:
+                        return new LoopBinaryExecution(loop, "sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001))", op, backend);
+                    case BinaryOpOperation_MINIMUM:
+                        return new LoopBinaryExecution(loop, "in0>in1?in1:in0", op, backend);
+                    case BinaryOpOperation_MAXIMUM:
+                        return new LoopBinaryExecution(loop, "in0>in1?in0:in1", op, backend);
+                    case BinaryOpOperation_GREATER:
+                        return new LoopBinaryExecution(loop, "convert_float4(-isgreater(in0,in1))", op, backend);
+                    case BinaryOpOperation_LESS:
+                        return new LoopBinaryExecution(loop, "convert_float4(-isless(in0,in1))", op, backend);
+                    case BinaryOpOperation_LESS_EQUAL:
+                        return new LoopBinaryExecution(loop, "convert_float4(-islessequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_GREATER_EQUAL:
+                        return new LoopBinaryExecution(loop, "convert_float4(-isgreaterequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_EQUAL:
+                        return new LoopBinaryExecution(loop, "convert_float4(-isequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_FLOORDIV:
+                        return new LoopBinaryExecution(loop, "floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    case BinaryOpOperation_FLOORMOD:
+                        return new LoopBinaryExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend);
+                    case BinaryOpOperation_POW:
+                        return new LoopBinaryExecution(loop, "pow(in0,in1)", op, backend);
+                    case BinaryOpOperation_SquaredDifference:
+                        return new LoopBinaryExecution(loop, "(in0-in1)*(in0-in1)", op, backend);
+                    case BinaryOpOperation_ATAN2:
+                        return new LoopBinaryExecution(loop, "atan(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))", op, backend);
+                    case BinaryOpOperation_NOTEQUAL:
+                        return new LoopBinaryExecution(loop, "convert_float4(-isnotequal(in0,in1))", op, backend);
+                    case BinaryOpOperation_MOD:
+                        return new LoopBinaryExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend);
+                    default:
+                        break;
+                }
+                return nullptr;
+            }
         }
         return nullptr;
     }
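Note: several of the expression strings registered above (REALDIV, FLOORDIV, FLOORMOD, MOD) share one guarded-division idiom: divide by max(|in1|, 1e-7) and restore the sign of in1, so a zero divisor yields 0 instead of inf or NaN. A minimal C++ sketch of that guard, and of the FLOORMOD formula built on top of it, follows; the function names are illustrative only.

#include <cmath>
#include <cstdio>

// Guarded division: divide by max(|b|, eps) and restore the sign of b,
// so b == 0 never produces inf/NaN.
static float safeDiv(float a, float b) {
    const float eps = 1e-7f;
    const float sb  = (float)((b > 0.0f) - (b < 0.0f));   // sign(b), 0 when b == 0
    return sb * a / std::fmax(std::fabs(b), eps);
}

// The FLOORMOD / MOD string: a - floor(a/b)*b, using the guarded division.
static float floorMod(float a, float b) {
    return a - std::floor(safeDiv(a, b)) * b;
}

int main() {
    printf("safeDiv(1, 0)   = %.3f\n", safeDiv(1.0f, 0.0f));   // 0, not inf
    printf("safeDiv(7, 2)   = %.3f\n", safeDiv(7.0f, 2.0f));   // 3.5
    printf("floorMod(7, 3)  = %.3f\n", floorMod(7.0f, 3.0f));  // 1
    printf("floorMod(-7, 3) = %.3f\n", floorMod(-7.0f, 3.0f)); // 2
    return 0;
}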

+ 12 - 0
source/backend/opencl/execution/image/LoopExecution.hpp

@@ -53,6 +53,18 @@ private:
     std::set<std::string> mBuildOptions;
 };
 
+class LoopBinaryExecution : public CommonExecution {
+public:
+    LoopBinaryExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn);
+    virtual ~LoopBinaryExecution() = default;
+    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
+
+private:
+    const LoopParam *mLoop;
+    std::vector<Tensor *> mTensors;
+    std::set<std::string> mBuildOptions;
+};
+
 } // namespace OpenCL
 } // namespace MNN
 #endif /* LoopExecution_hpp */

+ 37 - 43
source/backend/opencl/execution/image/SoftmaxExecution.cpp

@@ -19,10 +19,11 @@ SoftmaxExecution::SoftmaxExecution(const std::vector<Tensor *> &inputs, int axis
     mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
 }
 
-bool SoftmaxExecution::buildSoftmaxKernel() {
+bool SoftmaxExecution::buildSoftmaxKernel(int localSize) {
     auto runtime = mOpenCLBackend->getOpenCLRuntime();
     if (mKernel.get() == nullptr) {
         std::set<std::string> buildOptions;
+        buildOptions.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize));
         std::string kernelName;
         if (mAxis == 1) {
             mKernel           = runtime->buildKernel("softmax", "softmax_channel", buildOptions);
@@ -37,6 +38,14 @@ bool SoftmaxExecution::buildSoftmaxKernel() {
     return true;
 }
 
+int SoftmaxExecution::getLocalSize(int size, int maxGroupSize){
+    int local_size = 1;
+    while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){
+        local_size *= 2;
+    }
+    return local_size;
+}
+
 ErrorCode SoftmaxExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
     startRecord(mOpenCLBackend->getOpenCLRuntime(), mRecording);
     Tensor *input  = inputs[0];
@@ -68,61 +77,46 @@ ErrorCode SoftmaxExecution::onResize(const std::vector<Tensor *> &inputs, const
 
     const int channelBlocks  = UP_DIV(outputChannels, 4);
     const int remainChannels = channelBlocks * 4 - outputChannels;
+    auto MaxWorkItems = mOpenCLBackend->getOpenCLRuntime()->getMaxWorkItemSizes();
+    int localSize = getLocalSize(channel, MaxWorkItems[0]);
+    if(localSize < 4){
+        localSize = 1;
+    }
     if(inputBatch == outside && channel == inputChannels && inside == inputWidth * inputHeight){
         mAxis = 1;
-    }else if(inputBatch * inputChannels == outside && channel == inputHeight && inside == inputHeight){
+        localSize = getLocalSize(channelBlocks, MaxWorkItems[0]);
+    }else if(inputBatch * inputChannels == outside && channel == inputHeight && inside == inputWidth){
         mAxis = 2;
     }else if(inputBatch * inputChannels * inputHeight == outside && channel == inputWidth && inside == 1){
         mAxis = 3;
     }
-    buildSoftmaxKernel();
+    buildSoftmaxKernel(localSize);
     
     cl_int ret = CL_SUCCESS;
+    int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
     if (mAxis == 1) {
-        mGlobalWorkSize = {static_cast<uint32_t>(outputWidth),
-            static_cast<uint32_t>(outputHeight * outputBatch), 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-        
-        uint32_t idx    = 0;
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
-        ret |= mKernel.setArg(idx++, mGlobalWorkSize[2]);
-
-        ret |= mKernel.setArg(idx++, openCLImage(input));
-        ret |= mKernel.setArg(idx++, openCLImage(output));
-        ret |= mKernel.setArg(idx++, static_cast<int>(outputChannels));
-        ret |= mKernel.setArg(idx++, remainChannels);
-        ret |= mKernel.setArg(idx++, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxExecution Axis_1");
-
-        std::string kernelName = "softmax_channel";
-        mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)outputWidth, (uint32_t)outputHeight * outputBatch};
 
     } else if (mAxis == 2){
-        if (mMaxWorkGroupSize > 256) {
-            mLocalWorkSize = {16, 16, 1};
-        } else {
-            mLocalWorkSize = {8, 8, 1};
-        }
-        mGlobalWorkSize = {(uint32_t)channelBlocks*outputWidth, (uint32_t)outputBatch, 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-        ret |= mKernel.setArg(0, openCLImage(input));
-        ret |= mKernel.setArg(1, openCLImage(output));
-        ret |= mKernel.setArg(2, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxExecution Axis_2");
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks*outputWidth, (uint32_t)outputBatch};
     } else {
         MNN_ASSERT(mAxis == 3);
-        if (mMaxWorkGroupSize > 256) {
-            mLocalWorkSize = {16, 16, 1};
-        } else {
-            mLocalWorkSize = {8, 8, 1};
-        }
-        mGlobalWorkSize = {(uint32_t)channelBlocks, (uint32_t)outputBatch*outputHeight, 1};
-        int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth};
-        ret |= mKernel.setArg(0, openCLImage(input));
-        ret |= mKernel.setArg(1, openCLImage(output));
-        ret |= mKernel.setArg(2, shape);
-        MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxExecution Axis_3");
+        mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks, (uint32_t)outputBatch*outputHeight};
+    }
+    mLocalWorkSize = {(uint32_t)(localSize), 1, 1};
+    
+    uint32_t idx    = 0;
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
+    ret |= mKernel.setArg(idx++, mGlobalWorkSize[2]);
+
+    ret |= mKernel.setArg(idx++, openCLImage(input));
+    ret |= mKernel.setArg(idx++, openCLImage(output));
+    ret |= mKernel.setArg(idx++, remainChannels);
+    ret |= mKernel.setArg(idx++, shape);
+    MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxExecution");
+    if(localSize == 1){
+        mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "softmax", mKernel).first;
     }
     recordKernel3d(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime());
     endRecord(mOpenCLBackend->getOpenCLRuntime(), mRecording);
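Note: the resize path now sizes the reduction itself. getLocalSize picks the largest power of two that exceeds neither the reduction length nor the device work-group limit, values below 4 fall back to 1 (which selects the kernel's non-reduction branch), and the chosen value becomes both the -DSOFTMAX_LOCAL_SIZE build option and the first local/global work-size dimension. A standalone C++ sketch of that selection, with a hypothetical device limit:

#include <cstdio>

// Same logic as SoftmaxExecution::getLocalSize: largest power of two not
// exceeding either the reduction length or the max work-group size.
static int getLocalSize(int size, int maxGroupSize) {
    int localSize = 1;
    while (localSize * 2 <= maxGroupSize && localSize * 2 <= size) localSize *= 2;
    return localSize;
}

int main() {
    const int maxGroupSize = 256;                    // hypothetical device limit
    for (int size : {1, 3, 7, 64, 100, 1000}) {
        int ls = getLocalSize(size, maxGroupSize);
        if (ls < 4) ls = 1;                          // mirrors the onResize fallback
        printf("size %4d -> local size %d\n", size, ls);
    }
    return 0;
}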

+ 2 - 1
source/backend/opencl/execution/image/SoftmaxExecution.hpp

@@ -26,8 +26,9 @@ public:
     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
-    bool buildSoftmaxKernel();
+    bool buildSoftmaxKernel(int localSize);
 private:
+    int getLocalSize(int size, int maxGroupSize);
     cl::Kernel mKernel;
     uint32_t mMaxWorkGroupSize;
     OpenCLBackend *mOpenCLBackend;

+ 2 - 1
source/backend/tensorrt/backend/TRTBackend.cpp

@@ -356,7 +356,7 @@ void TRTBackend::onResizeBegin() {
     init();
 }
 
-void TRTBackend::onResizeEnd() {
+ErrorCode TRTBackend::onResizeEnd() {
 #ifdef TRT_LOG
     printf("\n\nTRTBackend onResizeEnd in\n");
 #endif
@@ -434,6 +434,7 @@ void TRTBackend::onResizeEnd() {
         delete l;
     }
     mEraseLayers.clear();
+    return NO_ERROR;
 }
 
 INetworkDefinition* TRTBackend::getNetwork() {

+ 2 - 1
source/backend/tensorrt/backend/TRTBackend.hpp

@@ -9,6 +9,7 @@
 #ifndef MNN_TRTBackend_H
 #define MNN_TRTBackend_H
 
+#include <MNN/ErrorCode.hpp>
 #include <core/Backend.hpp>
 #include <core/Execution.hpp>
 
@@ -88,7 +89,7 @@ public:
     virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
 
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
 
     class Creator {
     public:
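Note: TRTBackend (and the Vulkan backends further down) change onResizeEnd from void to ErrorCode so resize-time failures can propagate to the caller. The sketch below only illustrates the shape of such an override with mock Backend and ErrorCode types; it is not the MNN interface itself.

#include <cstdio>

// Mock types standing in for the real interface, for illustration only.
enum ErrorCode { NO_ERROR = 0, OUT_OF_MEMORY = 1 };

struct Backend {
    virtual void onResizeBegin() {}
    virtual ErrorCode onResizeEnd() = 0;   // previously: virtual void onResizeEnd()
    virtual ~Backend() = default;
};

struct MyBackend : Backend {
    ErrorCode onResizeEnd() override {
        // ... flush command buffers, finalize allocations ...
        return NO_ERROR;                   // report failures instead of swallowing them
    }
};

int main() {
    MyBackend b;
    printf("onResizeEnd -> %d\n", (int)b.onResizeEnd());
    return 0;
}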

+ 1 - 1
source/backend/tensorrt/execution/TRTConvolution.cpp

@@ -34,7 +34,7 @@ std::vector<ITensor *> TRTConvolution::onEncode(const std::vector<ITensor *> &xO
     int weightSize      = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
     if (nullptr != mOp->main_as_Convolution2D()->quanParameter()) {
-        quanWeight = ConvolutionCommon::load(mOp->main_as_Convolution2D()->quanParameter(), true);
+        quanWeight = ConvolutionCommon::load(mOp->main_as_Convolution2D(), backend(), true);
         srcCount   = quanWeight->weightFloat.size() / (outputCount * kernelX * kernelY);
         source     = quanWeight->weightFloat.get();
         weightSize = quanWeight->weightFloat.size();

+ 1 - 1
source/backend/tensorrt/execution/TRTDeconvolution.cpp

@@ -35,7 +35,7 @@ std::vector<ITensor *> TRTDeconvolution::onEncode(const std::vector<ITensor *> &
     int weightSize      = 0;
 
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2D, &source, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend(), conv2D, &source, &weightSize);
 
     nvinfer1::DimsHW NVKSize(kernelY, kernelX);
     nvinfer1::DimsHW NVKSSize(conv2DCommon->strideY(), conv2DCommon->strideX());

+ 1 - 1
source/backend/tensorrt/execution/TRTDepthwiseConvolution.cpp

@@ -36,7 +36,7 @@ std::vector<ITensor *> TRTDepthwiseConvolution::onEncode(const std::vector<ITens
     int weightSize      = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
     if (nullptr != mOp->main_as_Convolution2D()->quanParameter()) {
-        quanWeight = ConvolutionCommon::load(mOp->main_as_Convolution2D()->quanParameter(), true);
+        quanWeight = ConvolutionCommon::load(mOp->main_as_Convolution2D(), backend(), true);
         source     = quanWeight->weightFloat.get();
         weightSize = quanWeight->weightFloat.size();
     } else {

+ 1 - 1
source/backend/tensorrt/execution/TRTDepthwiseDeconvolution.cpp

@@ -35,7 +35,7 @@ std::vector<ITensor *> TRTDepthwiseDeconvolution::onEncode(const std::vector<ITe
     int weightSize      = 0;
     
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
-    ConvolutionCommon::getConvParameters(&quanCommon, conv2D, &source, &weightSize);
+    ConvolutionCommon::getConvParameters(&quanCommon, backend(), conv2D, &source, &weightSize);
 
     nvinfer1::DimsHW NVKSize(kernelY, kernelX);
     nvinfer1::DimsHW NVKSSize(conv2DCommon->strideY(), conv2DCommon->strideX());

+ 2 - 1
source/backend/vulkan/buffer/backend/VulkanBackend.cpp

@@ -131,10 +131,11 @@ void VulkanBackend::onResizeBegin() {
         mCmdBuffer->begin(0);
     }
 }
-void VulkanBackend::onResizeEnd() {
+ErrorCode VulkanBackend::onResizeEnd() {
     if (!mDirect) {
         mCmdBuffer->end();
     }
+    return NO_ERROR;
 }
 class VulkanMemRelease : public Backend::MemObj {
 public:

+ 2 - 1
source/backend/vulkan/buffer/backend/VulkanBackend.hpp

@@ -10,6 +10,7 @@
 #define VulkanBackend_hpp
 
 #include <map>
+#include <MNN/ErrorCode.hpp>
 #include "MNN_generated.h"
 #include "VulkanRuntime.hpp"
 namespace MNN {
@@ -27,7 +28,7 @@ public:
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
     virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
     virtual const Runtime* getRuntime() override {
         return mRuntime;

+ 1 - 1
source/backend/vulkan/buffer/execution/VulkanConvolution.cpp

@@ -289,7 +289,7 @@ public:
                     return nullptr;
                 }
             }
-            quanWeight = ConvolutionCommon::load(op->main_as_Convolution2D()->quanParameter(), true);
+            quanWeight = ConvolutionCommon::load(op->main_as_Convolution2D(), backend, true);
             srcCount = quanWeight->weightFloat.size() / (outputCount * fh * fw);
             source   = quanWeight->weightFloat.get();
             weightSize = quanWeight->weightFloat.size();

+ 1 - 1
source/backend/vulkan/buffer/execution/VulkanDeconvolution.cpp

@@ -45,7 +45,7 @@ VulkanDeconvolution* VulkanDeconvolution::create(Backend* bn, const Convolution2
     int tempWeightSize   = 0;
     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
     if (!multiInputs) {
-        ConvolutionCommon::getConvParameters(&quanCommon, conv, &tempWeight, &tempWeightSize);
+        ConvolutionCommon::getConvParameters(&quanCommon, bn, conv, &tempWeight, &tempWeightSize);
         MNN_ASSERT(nullptr != tempWeight);
         if (0 >= ci) {
             ci = tempWeightSize / co / kw / kh;

+ 2 - 1
source/backend/vulkan/image/backend/VulkanBackend.cpp

@@ -114,13 +114,14 @@ void VulkanBackend::onResizeBegin() {
         mCmdBuffer->begin(0);
     }
 }
-void VulkanBackend::onResizeEnd() {
+ErrorCode VulkanBackend::onResizeEnd() {
     if (!mDirect) {
         mCmdBuffer->end();
     }
     mInitBuffer->end();
     mCmdBuffers.emplace_back(mInitBuffer->get());
     _finish();
+    return NO_ERROR;
 }
 class VulkanMemRelease : public Backend::MemObj {
 public:

+ 2 - 1
source/backend/vulkan/image/backend/VulkanBackend.hpp

@@ -10,6 +10,7 @@
 #define VulkanBackend_hpp
 
 #include <map>
+#include <MNN/ErrorCode.hpp>
 #include "MNN_generated.h"
 #include "VulkanRuntime.hpp"
 #include "VulkanTensor.hpp"
@@ -31,7 +32,7 @@ public:
     virtual void onExecuteBegin() const override;
     virtual void onExecuteEnd() const override;
     virtual void onResizeBegin() override;
-    virtual void onResizeEnd() override;
+    virtual ErrorCode onResizeEnd() override;
     virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
 
     const VulkanPipeline* getPipeline(const std::string& key, const std::vector<VkDescriptorType>& types,

+ 0 - 0
source/backend/vulkan/image/execution/VulkanConvolution.cpp


Some files are not shown because too many files were changed in this diff